样本切分代码完成了

master
chenzenghui 2025-08-23 01:26:51 +08:00
parent cfed1d4638
commit d4b5539ce6
9 changed files with 1930 additions and 0 deletions

4
.gitignore vendored
View File

@ -713,3 +713,7 @@ FodyWeavers.xsd
hs_err_pid*
replay_pid*
# 其他数据处理程序
ProgramEXE/

View File

@ -0,0 +1,939 @@
"""
@Project microproduct
@File ImageHandle.py
@Function 实现对待处理SAR数据的读取格式标准化和处理完后保存文件功能
@Author LMM
@Date 2021/10/19 14:39
@Version 1.0.0
"""
import os
from PIL import Image
import time
from osgeo import gdal
from osgeo import osr
import numpy as np
from PIL import Image
import cv2
from xml.etree.ElementTree import ElementTree, Element
import math
class ImageHandler:
"""
影像读取编辑保存
"""
def __init__(self):
pass
@staticmethod
def get_dataset(filename):
"""
:param filename: tif路径
:return: 图像句柄
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
return dataset
def get_scope(self, filename):
"""
:param filename: tif路径
:return: 图像范围
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
im_scope = self.cal_img_scope(dataset)
del dataset
return im_scope
@staticmethod
def get_projection(filename):
"""
:param filename: tif路径
:return: 地图投影信息
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
im_proj = dataset.GetProjection()
del dataset
return im_proj
@staticmethod
def get_geotransform(filename):
"""
:param filename: tif路径
:return: 从图像坐标空间也称为像素线到地理参考坐标空间投影或地理坐标的仿射变换
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
geotransform = dataset.GetGeoTransform()
del dataset
return geotransform
def get_invgeotransform(filename):
"""
:param filename: tif路径
:return: 从地理参考坐标空间投影或地理坐标的到图像坐标空间
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
geotransform = dataset.GetGeoTransform()
geotransform=gdal.InvGeoTransform(geotransform)
del dataset
return geotransform
@staticmethod
def get_bands(filename):
"""
:param filename: tif路径
:return: 影像的波段数
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
bands = dataset.RasterCount
del dataset
return bands
@staticmethod
def geo2lonlat(dataset, x, y):
"""
将投影坐标转为经纬度坐标具体的投影坐标系由给定数据确定
:param dataset: GDAL地理数据
:param x: 投影坐标x
:param y: 投影坐标y
:return: 投影坐标(x, y)对应的经纬度坐标(lon, lat)
"""
prosrs = osr.SpatialReference()
prosrs.ImportFromWkt(dataset.GetProjection())
geosrs = prosrs.CloneGeogCS()
ct = osr.CoordinateTransformation(prosrs, geosrs)
coords = ct.TransformPoint(x, y)
return coords[:2]
@staticmethod
def get_band_array(filename, num=1):
"""
:param filename: tif路径
:param num: 波段序号
:return: 对应波段的矩阵数据
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
bands = dataset.GetRasterBand(num)
array = bands.ReadAsArray(0, 0, bands.XSize, bands.YSize)
# if 'int' in str(array.dtype):
# array[np.where(array == -9999)] = np.inf
# else:
# array[np.where(array < -9000.0)] = np.nan
del dataset
return array
@staticmethod
def write_imgArray(filename, im_data, no_data='null'):
"""
影像保存
:param filename: 保存的路径
:param im_proj:
:param im_geotrans:
:param im_data:
:param no_data: 把无效值设置为 nodata
:return:
"""
gdal_dtypes = {
'int8': gdal.GDT_Byte,
'unit16': gdal.GDT_UInt16,
'int16': gdal.GDT_Int16,
'unit32': gdal.GDT_UInt32,
'int32': gdal.GDT_Int32,
'float32': gdal.GDT_Float32,
'float64': gdal.GDT_Float64,
}
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
datatype = gdal_dtypes[im_data.dtype.name]
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
# 创建文件
if os.path.exists(os.path.split(filename)[0]) is False:
os.makedirs(os.path.split(filename)[0])
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
if im_bands == 1:
# outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据
outband = dataset.GetRasterBand(1)
outband.WriteArray(im_data)
if no_data != 'null':
outband.SetNoDataValue(no_data)
outband.FlushCache()
else:
for i in range(im_bands):
outband = dataset.GetRasterBand(1 + i)
outband.WriteArray(im_data[i])
outband.FlushCache()
# outRaster.GetRasterBand(i + 1).WriteArray(array[i])
del dataset
@staticmethod
def get_data(filename):
"""
:param filename: tif路径
:return: 获取所有波段的数据
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
del dataset
return im_data
@staticmethod
def get_scopes(ori_sim):
ori_sim_data = ImageHandler.get_data(ori_sim)
lon = ori_sim_data[0, :, :]
lat = ori_sim_data[1, :, :]
min_lon = np.nanmin(lon)
max_lon = np.nanmax(lon)
min_lat = np.nanmin(lat)
max_lat = np.nanmax(lat)
scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]
return scopes
@staticmethod
def get_dataset(filename):
"""
:param filename: tif路径
:return: 获取所有波段的数据
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
return dataset
@staticmethod
def get_all_band_array(filename):
"""
大气延迟算法
将ERA-5影像所有波段存为一个数组, 波段数在第三维度 get_data->3788
:param filename 影像路径 get_all_band_array ->8837
:return: 影像数组
"""
dataset = gdal.Open(filename)
x_size = dataset.RasterXSize
y_size = dataset.RasterYSize
nums = dataset.RasterCount
array = np.zeros((y_size, x_size, nums), dtype=float)
if nums == 1:
bands_0 = dataset.GetRasterBand(1)
array = bands_0.ReadAsArray(0, 0, x_size, y_size)
else:
for i in range(0, nums):
bands = dataset.GetRasterBand(i+1)
arr = bands.ReadAsArray(0, 0, x_size, y_size)
array[:, :, i] = arr
return array
@staticmethod
def get_img_width(filename):
"""
:param filename: tif路径
:return: 影像宽度
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
width = dataset.RasterXSize
del dataset
return width
@staticmethod
def get_img_height(filename):
"""
:param filename: tif路径
:return: 影像高度
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
height = dataset.RasterYSize
del dataset
return height
@staticmethod
def read_img(filename):
"""
影像读取
:param filename:
:return:
"""
gdal.AllRegister()
img_dataset = gdal.Open(filename) # 打开文件
if img_dataset is None:
msg = 'Could not open ' + filename
print(msg)
return None, None, None
im_proj = img_dataset.GetProjection() # 地图投影信息
if im_proj is None:
return None, None, None
im_geotrans = img_dataset.GetGeoTransform() # 仿射矩阵
im_width = img_dataset.RasterXSize # 栅格矩阵的行数
im_height = img_dataset.RasterYSize # 栅格矩阵的行数
im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height)
del img_dataset
return im_proj, im_geotrans, im_arr
def cal_img_scope(self, dataset):
"""
计算影像的地理坐标范围
根据GDAL的六参数模型将影像图上坐标行列号转为投影坐标或地理坐标根据具体数据的坐标系统转换
:param dataset :GDAL地理数据
:return: list[point_upleft, point_upright, point_downleft, point_downright]
"""
if dataset is None:
return None
img_geotrans = dataset.GetGeoTransform()
if img_geotrans is None:
return None
width = dataset.RasterXSize # 栅格矩阵的列数
height = dataset.RasterYSize # 栅格矩阵的行数
point_upleft = self.trans_rowcol2geo(img_geotrans, 0, 0)
point_upright = self.trans_rowcol2geo(img_geotrans, width, 0)
point_downleft = self.trans_rowcol2geo(img_geotrans, 0, height)
point_downright = self.trans_rowcol2geo(img_geotrans, width, height)
return [point_upleft, point_upright, point_downleft, point_downright]
@staticmethod
def get_scope_ori_sim(filename):
"""
计算影像的地理坐标范围
根据GDAL的六参数模型将影像图上坐标行列号转为投影坐标或地理坐标根据具体数据的坐标系统转换
:param dataset :GDAL地理数据
:return: list[point_upleft, point_upright, point_downleft, point_downright]
"""
gdal.AllRegister()
dataset = gdal.Open(filename)
if dataset is None:
return None
width = dataset.RasterXSize # 栅格矩阵的列数
height = dataset.RasterYSize # 栅格矩阵的行数
band1 = dataset.GetRasterBand(1)
array1 = band1.ReadAsArray(0, 0, band1.XSize, band1.YSize)
band2 = dataset.GetRasterBand(2)
array2 = band2.ReadAsArray(0, 0, band2.XSize, band2.YSize)
if array1[0, 0] < array1[0, width-1]:
point_upleft = [array1[0, 0], array2[0, 0]]
point_upright = [array1[0, width-1], array2[0, width-1]]
else:
point_upright = [array1[0, 0], array2[0, 0]]
point_upleft = [array1[0, width-1], array2[0, width-1]]
if array1[height-1, 0] < array1[height-1, width-1]:
point_downleft = [array1[height - 1, 0], array2[height - 1, 0]]
point_downright = [array1[height - 1, width - 1], array2[height - 1, width - 1]]
else:
point_downright = [array1[height - 1, 0], array2[height - 1, 0]]
point_downleft = [array1[height - 1, width - 1], array2[height - 1, width - 1]]
if(array2[0, 0] < array2[height - 1, 0]):
#上下调换顺序
tmp1 = point_upleft
point_upleft = point_downleft
point_downleft = tmp1
tmp2 = point_upright
point_upright = point_downright
point_downright = tmp2
return [point_upleft, point_upright, point_downleft, point_downright]
@staticmethod
def trans_rowcol2geo(img_geotrans,img_col, img_row):
"""
据GDAL的六参数模型仿射矩阵将影像图上坐标行列号转为投影坐标或地理坐标根据具体数据的坐标系统转换
:param img_geotrans: 仿射矩阵
:param img_col:图像纵坐标
:param img_row:图像横坐标
:return: [geo_x,geo_y]
"""
geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row
geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row
return [geo_x, geo_y]
@staticmethod
def write_era_into_img(filename, im_proj, im_geotrans, im_data):
"""
影像保存
:param filename:
:param im_proj:
:param im_geotrans:
:param im_data:
:return:
"""
gdal_dtypes = {
'int8': gdal.GDT_Byte,
'unit16': gdal.GDT_UInt16,
'int16': gdal.GDT_Int16,
'unit32': gdal.GDT_UInt32,
'int32': gdal.GDT_Int32,
'float32': gdal.GDT_Float32,
'float64': gdal.GDT_Float64,
}
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
datatype = gdal_dtypes[im_data.dtype.name]
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) == 3:
im_height, im_width, im_bands = im_data.shape # shape[0] 行数
else:
im_bands, (im_height, im_width) = 1, im_data.shape
# 创建文件
if os.path.exists(os.path.split(filename)[0]) is False:
os.makedirs(os.path.split(filename)[0])
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
if im_bands == 1:
dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据
else:
for i in range(im_bands):
dataset.GetRasterBand(i + 1).WriteArray(im_data[:, :, i])
# dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
# 写GeoTiff文件
@staticmethod
def lat_lon_to_pixel(raster_dataset_path, location):
"""From zacharybears.com/using-python-to-translate-latlon-locations-to-pixels-on-a-geotiff/."""
gdal.AllRegister()
raster_dataset = gdal.Open(raster_dataset_path)
if raster_dataset is None:
return None
ds = raster_dataset
gt = ds.GetGeoTransform()
srs = osr.SpatialReference()
srs.ImportFromWkt(ds.GetProjection())
srs_lat_lon = srs.CloneGeogCS()
ct = osr.CoordinateTransformation(srs_lat_lon, srs)
new_location = [None, None]
# Change the point locations into the GeoTransform space
(new_location[1], new_location[0], holder) = ct.TransformPoint(location[1], location[0])
# Translate the x and y coordinates into pixel values
Xp = new_location[0]
Yp = new_location[1]
dGeoTrans = gt
dTemp = dGeoTrans[1] * dGeoTrans[5] - dGeoTrans[2] * dGeoTrans[4]
Xpixel = (dGeoTrans[5] * (Xp - dGeoTrans[0]) - dGeoTrans[2] * (Yp - dGeoTrans[3])) / dTemp
Yline = (dGeoTrans[1] * (Yp - dGeoTrans[3]) - dGeoTrans[4] * (Xp - dGeoTrans[0])) / dTemp
del raster_dataset
return (Xpixel, Yline)
@staticmethod
def write_img(filename, im_proj, im_geotrans, im_data, no_data='0'):
"""
影像保存
:param filename: 保存的路径
:param im_proj:
:param im_geotrans:
:param im_data:
:param no_data: 把无效值设置为 nodata
:return:
"""
gdal_dtypes = {
'int8': gdal.GDT_Byte,
'unit16': gdal.GDT_UInt16,
'int16': gdal.GDT_Int16,
'unit32': gdal.GDT_UInt32,
'int32': gdal.GDT_Int32,
'float32': gdal.GDT_Float32,
'float64': gdal.GDT_Float64,
}
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
datatype = gdal_dtypes[im_data.dtype.name]
else:
datatype = gdal.GDT_Float32
flag = False
# 判读数组维数
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
flag = True
else:
im_bands, (im_height, im_width) = 1, im_data.shape
# 创建文件
if os.path.exists(os.path.split(filename)[0]) is False:
os.makedirs(os.path.split(filename)[0])
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
if im_bands == 1:
# outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据
if flag:
outband = dataset.GetRasterBand(1)
outband.WriteArray(im_data[0])
if no_data != 'null':
outband.SetNoDataValue(np.double(no_data))
outband.FlushCache()
else:
outband = dataset.GetRasterBand(1)
outband.WriteArray(im_data)
if no_data != 'null':
outband.SetNoDataValue(np.double(no_data))
outband.FlushCache()
else:
for i in range(im_bands):
outband = dataset.GetRasterBand(1 + i)
outband.WriteArray(im_data[i])
if no_data != 'null':
outband.SetNoDataValue(np.double(no_data))
outband.FlushCache()
# outRaster.GetRasterBand(i + 1).WriteArray(array[i])
del dataset
# 写GeoTiff文件
@staticmethod
def write_img_envi(filename, im_proj, im_geotrans, im_data, no_data='null'):
"""
影像保存
:param filename: 保存的路径
:param im_proj:
:param im_geotrans:
:param im_data:
:param no_data: 把无效值设置为 nodata
:return:
"""
gdal_dtypes = {
'int8': gdal.GDT_Byte,
'unit16': gdal.GDT_UInt16,
'int16': gdal.GDT_Int16,
'unit32': gdal.GDT_UInt32,
'int32': gdal.GDT_Int32,
'float32': gdal.GDT_Float32,
'float64': gdal.GDT_Float64,
}
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
datatype = gdal_dtypes[im_data.dtype.name]
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
# 创建文件
if os.path.exists(os.path.split(filename)[0]) is False:
os.makedirs(os.path.split(filename)[0])
driver = gdal.GetDriverByName("ENVI") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
if im_bands == 1:
# outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据
outband = dataset.GetRasterBand(1)
outband.WriteArray(im_data)
if no_data != 'null':
outband.SetNoDataValue(no_data)
outband.FlushCache()
else:
for i in range(im_bands):
outband = dataset.GetRasterBand(1 + i)
outband.WriteArray(im_data[i])
outband.FlushCache()
# outRaster.GetRasterBand(i + 1).WriteArray(array[i])
del dataset
@staticmethod
def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
"""
图像中写入rpc信息
"""
# 判断栅格数据的数据类型
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_Int16
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
# 创建文件
driver = gdal.GetDriverByName("GTiff")
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
# 写入RPC参数
for k in rpc_dict.keys():
dataset.SetMetadataItem(k, rpc_dict[k], 'RPC')
if im_bands == 1:
dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据
else:
for i in range(im_bands):
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
def transtif2mask(self,out_tif_path, in_tif_path, threshold):
"""
:param out_tif_path:输出路径
:param in_tif_path:输入的路径
:param threshold:阈值
"""
im_proj, im_geotrans, im_arr, im_scope = self.read_img(in_tif_path)
im_arr_mask = (im_arr < threshold).astype(int)
self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)
def write_quick_view(self, tif_path, color_img=False, quick_view_path=None):
"""
生成快视图,默认快视图和影像同路径且同名
:param tif_path:影像路径
:param color_img:是否生成随机伪彩色图
:param quick_view_path:快视图路径
"""
if quick_view_path is None:
quick_view_path = os.path.splitext(tif_path)[0]+'.jpg'
n = self.get_bands(tif_path)
if n == 1: # 单波段
t_data = self.get_data(tif_path)
else: # 多波段,转为强度数据
t_data = self.get_data(tif_path)
t_data = t_data.astype(float)
t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)
t_data[np.isnan(t_data)] = 0
t_data[np.where(t_data == -9999)] = 0
t_r = self.get_img_height(tif_path)
t_c = self.get_img_width(tif_path)
if t_r > 10000 or t_c > 10000:
q_r = int(t_r / 10)
q_c = int(t_c / 10)
elif 1024 < t_r < 10000 or 1024 < t_c < 10000:
if t_r > t_c:
q_r = 1024
q_c = int(t_c/t_r * 1024)
else:
q_c = 1024
q_r = int(t_r/t_c * 1024)
else:
q_r = t_r
q_c = t_c
if color_img is True:
# 生成伪彩色图
img = np.zeros((t_r, t_c, 3), dtype=np.uint8) # (高,宽,维度)
u = np.unique(t_data)
for i in u:
if i != 0:
w = np.where(t_data == i)
img[w[0], w[1], 0] = np.random.randint(0, 255) # 随机生成一个0到255之间的整数 可以通过挑参数设定不同的颜色范围
img[w[0], w[1], 1] = np.random.randint(0, 255)
img[w[0], w[1], 2] = np.random.randint(0, 255)
img = cv2.resize(img, (q_c, q_r)) # (宽,高)
cv2.imwrite(quick_view_path, img)
# cv2.imshow("result4", img)
# cv2.waitKey(0)
else:
# 灰度图
min = np.percentile(t_data, 2) # np.nanmin(t_data)
max = np.percentile(t_data, 98) # np.nanmax(t_data)
# if (max - min) < 256:
t_data = (t_data - min) / (max - min) * 255
out_img = Image.fromarray(t_data)
out_img = out_img.resize((q_c, q_r)) # 重采样
out_img = out_img.convert("L") # 转换成灰度图
out_img.save(quick_view_path)
def get_center_scopes(self, dataset):
if dataset is None:
return None
img_geotrans = dataset.GetGeoTransform()
if img_geotrans is None:
return None
width = dataset.RasterXSize # 栅格矩阵的列数
height = dataset.RasterYSize # 栅格矩阵的行数
x_split = int(width/5)
y_split = int(height/5)
img_col_start = x_split * 1
img_col_end = x_split * 3
img_row_start = y_split * 1
img_row_end = y_split *3
cols = img_col_end - img_col_start
rows = img_row_end - img_row_start
if cols > 10000 or rows > 10000:
img_col_end = img_col_start + 10000
img_row_end = img_row_start + 10000
point_upleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_start)
point_upright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_start)
point_downleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_end)
point_downright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_end)
return [point_upleft, point_upright, point_downleft, point_downright]
def write_view(self, tif_path, color_img=False, quick_view_path=None):
"""
生成快视图,默认快视图和影像同路径且同名
:param tif_path:影像路径
:param color_img:是否生成随机伪彩色图
:param quick_view_path:快视图路径
"""
if quick_view_path is None:
quick_view_path = os.path.splitext(tif_path)[0]+'.jpg'
n = self.get_bands(tif_path)
if n == 1: # 单波段
t_data = self.get_data(tif_path)
else: # 多波段,转为强度数据
t_data = self.get_data(tif_path)
t_data = t_data.astype(float)
t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)
t_data[np.isnan(t_data)] = 0
t_data[np.where(t_data == -9999)] = 0
t_r = self.get_img_height(tif_path)
t_c = self.get_img_width(tif_path)
q_r = t_r
q_c = t_c
if color_img is True:
# 生成伪彩色图
img = np.zeros((t_r, t_c, 3), dtype=np.uint8) # (高,宽,维度)
u = np.unique(t_data)
for i in u:
if i != 0:
w = np.where(t_data == i)
img[w[0], w[1], 0] = np.random.randint(0, 255) # 随机生成一个0到255之间的整数 可以通过挑参数设定不同的颜色范围
img[w[0], w[1], 1] = np.random.randint(0, 255)
img[w[0], w[1], 2] = np.random.randint(0, 255)
img = cv2.resize(img, (q_c, q_r)) # (宽,高)
cv2.imwrite(quick_view_path, img)
# cv2.imshow("result4", img)
# cv2.waitKey(0)
else:
# 灰度图
min = np.percentile(t_data, 2) # np.nanmin(t_data)
max = np.percentile(t_data, 98) # np.nanmax(t_data)
# if (max - min) < 256:
t_data = (t_data - min) / (max - min) * 255
out_img = Image.fromarray(t_data)
out_img = out_img.resize((q_c, q_r)) # 重采样
out_img = out_img.convert("L") # 转换成灰度图
out_img.save(quick_view_path)
return quick_view_path
def limit_field(self, out_path, in_path, min_value, max_value):
"""
:param out_path:输出路径
:param in_path:主mask路径输出影像采用主mask的地理信息
:param min_value
:param max_value
"""
proj = self.get_projection(in_path)
geotrans = self.get_geotransform(in_path)
array = self.get_band_array(in_path, 1)
array[array < min_value] = min_value
array[array > max_value] = max_value
self.write_img(out_path, proj, geotrans, array)
return True
def band_merge(self, lon, lat, ori_sim):
lon_arr = self.get_data(lon)
lat_arr = self.get_data(lat)
temp = np.zeros((2, lon_arr.shape[0], lon_arr.shape[1]), dtype=float)
temp[0, :, :] = lon_arr[:, :]
temp[1, :, :] = lat_arr[:, :]
self.write_img(ori_sim, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], temp, '0')
def get_scopes(self, ori_sim):
ori_sim_data = self.get_data(ori_sim)
lon = ori_sim_data[0, :, :]
lat = ori_sim_data[1, :, :]
min_lon = np.nanmin(np.where((lon != 0) & ~np.isnan(lon), lon, np.inf))
max_lon = np.nanmax(np.where((lon != 0) & ~np.isnan(lon), lon, -np.inf))
min_lat = np.nanmin(np.where((lat != 0) & ~np.isnan(lat), lat, np.inf))
max_lat = np.nanmax(np.where((lat != 0) & ~np.isnan(lat), lat, -np.inf))
scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]
return scopes
@staticmethod
def dem_merged(in_dem_path, out_dem_path):
'''
DEM重采样函数默认坐标系为WGS84
agrs:
in_dem_path: 输入的DEM文件夹路径
meta_file_path: 输入的xml元文件路径
out_dem_path: 输出的DEM文件夹路径
'''
# 读取文件夹中所有的DEM
dem_file_paths = [os.path.join(in_dem_path, dem_name) for dem_name in os.listdir(in_dem_path) if
dem_name.find(".tif") >= 0 and dem_name.find(".tif.") == -1]
spatialreference = osr.SpatialReference()
spatialreference.SetWellKnownGeogCS("WGS84") # 设置地理坐标,单位为度 degree # 设置投影坐标,单位为度 degree
spatialproj = spatialreference.ExportToWkt() # 导出投影结果
# 将DEM拼接成一张大图
mergeFile = gdal.BuildVRT(os.path.join(out_dem_path, "mergedDEM_VRT.tif"), dem_file_paths)
out_DEM = os.path.join(out_dem_path, "mergedDEM.tif")
gdal.Warp(out_DEM,
mergeFile,
format="GTiff",
dstSRS=spatialproj,
dstNodata=-9999,
outputType=gdal.GDT_Float32)
time.sleep(3)
# gdal.CloseDir(out_DEM)
return out_DEM
@staticmethod
def landcover_merged(in_dem_path, out_dem_path):
'''
DEM重采样函数默认坐标系为WGS84
agrs:
in_dem_path: 输入的DEM文件夹路径
meta_file_path: 输入的xml元文件路径
out_dem_path: 输出的DEM文件夹路径
'''
# 读取文件夹中所有的DEM
dem_file_paths = [os.path.join(in_dem_path, dem_name) for dem_name in os.listdir(in_dem_path) if
dem_name.find(".tif") >= 0 and dem_name.find(".tif.") == -1]
spatialreference = osr.SpatialReference()
spatialreference.SetWellKnownGeogCS("WGS84") # 设置地理坐标,单位为度 degree # 设置投影坐标,单位为度 degree
spatialproj = spatialreference.ExportToWkt() # 导出投影结果
# 将DEM拼接成一张大图
mergeFile = gdal.BuildVRT(os.path.join(out_dem_path, "mergedDEM_VRT.tif"), dem_file_paths)
out_DEM = os.path.join(out_dem_path, "mergedDEM.tif")
gdal.Warp(out_DEM,
mergeFile,
format="GTiff",
dstSRS=spatialproj,
dstNodata=-9999,
outputType=gdal.GDT_Byte)
time.sleep(3)
# gdal.CloseDir(out_DEM)
return out_DEM
@staticmethod
def get_inc_angle(inc_xml, rows, cols, out_path):
tree = ElementTree()
tree.parse(inc_xml) # 影像头文件
root = tree.getroot()
values = root.findall('incidenceValue')
angle_value = [value.text for value in values]
angle_value = np.array(angle_value)
inc_angle = np.tile(angle_value, (rows, 1))
ImageHandler.write_img(out_path, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], inc_angle)
pass
if __name__ == '__main__':
cols = 7086
rows = 8064
inc_xml = r"D:\micro\WorkSpace\Dem\Temporary\processing\product\GF3_SAY_FSI_001614_E113.2_N34.5_20161129_L1A_HHHV_L10002015686-DEM.tiff"
# ImageHandler().write_quick_view(inc_xml)
# fn = r'E:\202306hb\result\GF3B_SYC_QPSI_008316_E116.1_N43.3_20230622_L1A_AHV_L10000202892-cal-SMC.tif'
# out = r'E:\202306hb\result\soil.tif'
# #
# im_proj, im_geotrans, im_arr = ImageHandler.read_img(fn)
# im_arr[np.where(np.isnan(im_arr))] = 0
# # h,w = im_arr.shape
# # arr = np.random.rand(h,w)*0.4
# ImageHandler.write_img(out, im_proj, im_geotrans, im_arr, '-1')
# path = r'D:\BaiduNetdiskDownload\GZ\lon.rdr'
# path2 = r'D:\BaiduNetdiskDownload\GZ\lat.rdr'
# path3 = r'D:\BaiduNetdiskDownload\GZ\lon_lat.tif'
# s = ImageHandler().band_merge(path, path2, path3)
# print(s)
# pass
fileDir = r'D:\micro\WorkSpace\ortho\Temporary\VH_db'
outDir = r'D:\micro\WorkSpace\ortho\Temporary\VH_db\test'
files = os.listdir(fileDir)
for file in files:
tifFile = os.path.join(fileDir, file)
outFile = os.path.join(outDir, file)
im_proj, im_geotrans, im_arr = ImageHandler.read_img(tifFile)
im_arr[np.isnan(im_arr)] = 0
im_arr[np.isinf(im_arr)] = 0
ImageHandler.write_img(outFile, im_proj, im_geotrans, im_arr, '0')

View File

@ -0,0 +1,508 @@
import os
import argparse
from osgeo import ogr,gdal
from matplotlib import pyplot as plt
from osgeo import gdal
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.patches as patches
from osgeo import gdal
from PIL import Image
from scipy.spatial import cKDTree
import numpy as np
from DotaOperator import DotaObj,createDota,readDotaFile,writerDotaFile
import argparse
import math
from math import ceil, floor
##########################################################################
# 函数区
##########################################################################
def read_tif(path):
dataset = gdal.Open(path) # 打开TIF文件
if dataset is None:
print("无法打开文件")
return None, None, None
cols = dataset.RasterXSize # 图像宽度
rows = dataset.RasterYSize # 图像高度
bands = dataset.RasterCount
im_proj = dataset.GetProjection() # 获取投影信息
im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息
im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组
print("行数:", rows)
print("列数:", cols)
print("波段:", bands)
del dataset # 关闭数据集
return im_proj, im_Geotrans, im_data
def write_envi(im_data, im_geotrans, im_proj, output_path):
"""
将数组数据写入ENVI格式文件
:param im_data: 输入的numpy数组2D或3D
:param im_geotrans: 仿射变换参数6元组
:param im_proj: 投影信息WKT字符串
:param output_path: 输出文件路径无需扩展名会自动生成.dat和.hdr
"""
im_bands = 1
im_height, im_width = im_data.shape
# 创建ENVI格式驱动
driver = gdal.GetDriverByName("ENVI")
dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte)
if dataset is not None:
dataset.SetGeoTransform(im_geotrans) # 设置地理变换参数
dataset.SetProjection(im_proj) # 设置投影
dataset.GetRasterBand(1).WriteArray(im_data)
dataset.FlushCache() # 确保数据写入磁盘
dataset = None # 关闭文件
def geoXY2pixelXY(geo_x, geo_y, inv_gt):
pixel_x = inv_gt[0] + geo_x * inv_gt[1] + geo_y * inv_gt[2]
pixel_y = inv_gt[3] + geo_x * inv_gt[4] + geo_y * inv_gt[5]
return pixel_x, pixel_y
def label2pixelpoints(dotapath,tiff_inv_trans,methodstr):
dotalist = readDotaFile(dotapath)
if methodstr=="geolabel":
for i in range(len(dotalist)):
geo_x = dotalist[i].x1 # x1
geo_y = dotalist[i].y1
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
dotalist[i].x1 = pixel_x
dotalist[i].y1 = pixel_y
geo_x = dotalist[i].x2 # x2
geo_y = dotalist[i].y2
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
dotalist[i].x2 = pixel_x
dotalist[i].y2 = pixel_y
geo_x = dotalist[i].x3 # x3
geo_y = dotalist[i].y3
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
dotalist[i].x3 = pixel_x
dotalist[i].y3 = pixel_y
geo_x = dotalist[i].x4 # x4
geo_y = dotalist[i].y4
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
dotalist[i].x4 = pixel_x
dotalist[i].y4 = pixel_y
print("点数:", len(dotalist))
return dotalist
def getMaxEdge(dotalist, ids):
cornpoint = np.zeros((len(ids) * 4, 2))
for idx in range(len(ids)):
cornpoint[idx * 4 + 0, 0] = dotalist[ids[idx]].x1
cornpoint[idx * 4 + 1, 0] = dotalist[ids[idx]].x2
cornpoint[idx * 4 + 2, 0] = dotalist[ids[idx]].x3
cornpoint[idx * 4 + 3, 0] = dotalist[ids[idx]].x4
cornpoint[idx * 4 + 0, 1] = dotalist[ids[idx]].y1
cornpoint[idx * 4 + 1, 1] = dotalist[ids[idx]].y2
cornpoint[idx * 4 + 2, 1] = dotalist[ids[idx]].y3
cornpoint[idx * 4 + 3, 1] = dotalist[ids[idx]].y4
xedge = np.max(cornpoint[:, 0]) - np.min(cornpoint[:, 0])
yedge = np.max(cornpoint[:, 1]) - np.min(cornpoint[:, 1])
edgelen = xedge if xedge > yedge else yedge
return edgelen
def getExternCenter(dotalist, ids):
cornpoint = np.zeros((len(ids) * 4, 2))
for idx in range(len(ids)):
cornpoint[idx * 4 + 0, 0] = dotalist[ids[idx]].x1
cornpoint[idx * 4 + 1, 0] = dotalist[ids[idx]].x2
cornpoint[idx * 4 + 2, 0] = dotalist[ids[idx]].x3
cornpoint[idx * 4 + 3, 0] = dotalist[ids[idx]].x4
cornpoint[idx * 4 + 0, 1] = dotalist[ids[idx]].y1
cornpoint[idx * 4 + 1, 1] = dotalist[ids[idx]].y2
cornpoint[idx * 4 + 2, 1] = dotalist[ids[idx]].y3
cornpoint[idx * 4 + 3, 1] = dotalist[ids[idx]].y4
minX = np.min(cornpoint[:, 0])
minY = np.min(cornpoint[:, 1])
maxX = np.max(cornpoint[:, 0])
maxY = np.max(cornpoint[:, 1])
centerX = (minX + maxX) / 2
centerY = (minY + maxY) / 2
return [centerX, centerY, minX, minY, maxX, maxY]
def drawSliceRasterPrivew(tiff_data,dotalist,clusterDict):
# 绘制图形
# 创建图形和坐标轴
fig, ax = plt.subplots(figsize=(20, 16))
ax.imshow(tiff_data, cmap='gray')
# 绘制每个目标的矩形框并标注坐标
for i in range(len(dotalist)):
# 提取x和y坐标
x_coords = [dotalist[i].x1, dotalist[i].x2, dotalist[i].x3, dotalist[i].x4]
y_coords = [dotalist[i].y1, dotalist[i].y2, dotalist[i].y3, dotalist[i].y4]
# 计算最小外接矩形AABB
x_min, x_max = min(x_coords), max(x_coords)
y_min, y_max = min(y_coords), max(y_coords)
width = x_max - x_min
height = y_max - y_min
# 绘制无填充矩形框(仅红色边框)
rect = patches.Rectangle(
(x_min, y_min), width, height,
linewidth=2, edgecolor='red', facecolor='none' # 关键facecolor='none'
)
ax.add_patch(rect)
# ax.annotate(f'({x},{y})', xy=(x, y), xytext=(5, 5),
# textcoords='offset points', fontsize=10,
# bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.8))
# 在矩形中心标注目标编号
center_x = sum(x_coords) / 4
center_y = sum(y_coords) / 4
ax.text(center_x, center_y, str(i),
ha='center', va='center', fontsize=6, color='red')
# 以类别中心为中心绘制四边形
for k in clusterDict:
# 绘制无填充矩形框(仅红色边框)
minX = clusterDict[k]["p"][0]
minY = clusterDict[k]["p"][1]
rect = patches.Rectangle(
(minX , minY), 1024, 1024,
linewidth=2, edgecolor='green', facecolor='none' # 关键facecolor='none'
)
ax.add_patch(rect)
ax.text(minX+512, minY+512, str(k),
ha='center', va='center', fontsize=6, color='green')
plt.tight_layout()
plt.show()
print("绘图结束")
return None
def find_optimal_slices(H, W, boxes, patch_size=1024, max_overlap_rate=0.2):
"""
Compute optimal slice positions for the image to maximize the number of fully contained rectangular patches (boxes),
while ensuring the overlap rate between any two slices does not exceed the specified maximum.
Parameters:
- H: Height of the image.
- W: Width of the image.
- boxes: List of tuples or lists, each containing (x1, y1, x2, y2) where (x1, y1) is the top-left and (x2, y2) is the bottom-right of a rectangular patch.
- patch_size: Size of each slice (square, e.g., 1024).
- max_overlap_rate: Maximum allowed overlap rate (e.g., 0.2).
Returns:
- slices: List of (sx, sy) starting positions for the slices.
- covered_count: Number of patches that appear fully in at least one slice.
"""
overlap_max = patch_size * max_overlap_rate
stride = patch_size - floor(overlap_max) # Ensures overlap <= max_overlap_rate in linear dimensions
N = len(boxes)
x_covered = [set() for _ in range(stride)]
y_covered = [set() for _ in range(stride)]
for i in range(N):
x1, y1, x2, y2 = boxes[i]
b_w = x2 - x1
b_h = y2 - y1
# For x-dimension
lx = max(0, x2 - patch_size)
rx = min(W - patch_size, x1)
if lx <= rx:
start_x = ceil(lx)
end_x = floor(rx)
l_x = end_x - start_x + 1
if l_x > 0:
if l_x >= stride:
for ox in range(stride):
x_covered[ox].add(i)
else:
for sx in range(start_x, end_x + 1):
ox = sx % stride
x_covered[ox].add(i)
# For y-dimension
ly = max(0, y2 - patch_size)
ry = min(H - patch_size, y1)
if ly <= ry:
start_y = ceil(ly)
end_y = floor(ry)
l_y = end_y - start_y + 1
if l_y > 0:
if l_y >= stride:
for oy in range(stride):
y_covered[oy].add(i)
else:
for sy in range(start_y, end_y + 1):
oy = sy % stride
y_covered[oy].add(i)
# Find the best offset pair (ox, oy) that maximizes covered patches
max_covered = 0
best_ox = 0
best_oy = 0
for ox in range(stride):
for oy in range(stride):
current_covered = len(x_covered[ox] & y_covered[oy])
if current_covered > max_covered:
max_covered = current_covered
best_ox = ox
best_oy = oy
# Generate the slice positions using the best offsets and stride
slices = []
sx = best_ox
while sx + patch_size <= W:
sy = best_oy
while sy + patch_size <= H:
slices.append((sx, sy))
sy += stride
sx += stride
return slices, max_covered
def check_B_in_A(A,B):
"""
判断A包含B
:param A: [x0,y0.w.h]
:param B: [x0,y0.w.h]
:return:
"""
# 解构矩形A和B的参数
Ax0, Ay0, Aw, Ah = A
Bx0, By0, Bw, Bh = B
# 计算矩形A和B的右边界和下边界
Ax1 = Ax0 + Aw
Ay1 = Ay0 + Ah
Bx1 = Bx0 + Bw
By1 = By0 + Bh
# 判断B是否完全在A内部
return (Bx0 >= Ax0) and (Bx1 <= Ax1) and (By0 >= Ay0) and (By1 <= Ay1)
##########################################################################
# 切分算法流程图
##########################################################################
def getclusterDict(dotalist,imgheight,imgwidth,pitchSize=1024,max_overlap_rate=0.2):
"""
生成切片数据
:param dotalist: 样本集
:param imgheight: 图像高度
:param imgwidth: 图像宽度
:return: 切片类型
"""
boxs=[]
for i in range(len(dotalist)):
xs=np.array([dotalist[i].x1,dotalist[i].x2,dotalist[i].x3, dotalist[i].x4])
ys=np.array([dotalist[i].y1,dotalist[i].y2,dotalist[i].y3, dotalist[i].y4])
x1=np.min(xs)
x2=np.max(xs)
y1=np.min(ys)
y2=np.max(ys)
boxs.append([x1,y1,x2,y2]) # x1, y1, x2, y2 = boxes[i]
slices, max_covered=find_optimal_slices(imgheight,imgwidth,boxs,pitchSize,max_overlap_rate)
clusterDict={}
waitContaindota=[]
hasContainIds=[]
for i in range(len(slices)):
sx,sy=slices[i]
clusterDict[i]={"p":[sx,sy],"id":[]}
slicesExten=[sx,sy,1024,1024]
for ids in range(len(dotalist)):
if ids in hasContainIds:
continue
else:
[centerX, centerY, minX, minY, maxX, maxY]=getExternCenter(dotalist, [ids])
dotaExtend=[minX,minY,maxX-minX,maxY-minY]
if check_B_in_A(slicesExten,dotaExtend):
clusterDict[i]["id"].append(ids)
hasContainIds.append(ids)
for ids in range(len(dotalist)):
if ids in hasContainIds:
continue
else:
waitContaindota.append(ids)
print("No in slice dota : ",str(dotalist[ids]) )
print("no process ids ",str(waitContaindota))
return clusterDict
def drawSlictplot(clusterDict,dotalist,tiff_data,nrows=10,ncols=9):
"""
:param clusterDict: clusterDict[i]={"p":[sx,sy],"id":[]}
:param dotalist: x1, y1, x2, y2, x3, y3, x4, y4 clsname diffcule
:return:
"""
fig, axes = plt.subplots(nrows=nrows,ncols=ncols,figsize=(20, 16))
plt.tight_layout(pad=3.0)
# 9*10
subid=0
for cid in clusterDict:
sx,sy=clusterDict[cid]["p"]
colid=subid//nrows
rowid=subid%nrows
subid=subid+1
ax = axes[rowid, colid]
ax.set_title(str(cid))
sliceData=tiff_data[sy:(sy+1024),sx:(sx+1024)]
ax.imshow(sliceData, cmap='gray')
for did in clusterDict[cid]["id"] :
# 提取x和y坐标
x_coords = [dotalist[did].x1-sx, dotalist[did].x2-sx, dotalist[did].x3-sx, dotalist[did].x4-sx]
y_coords = [dotalist[did].y1-sy, dotalist[did].y2-sy, dotalist[did].y3-sy, dotalist[did].y4-sy]
# 计算最小外接矩形AABB
x_min, x_max = min(x_coords), max(x_coords)
y_min, y_max = min(y_coords), max(y_coords)
width = x_max - x_min
height = y_max - y_min
# 绘制无填充矩形框(仅红色边框)
rect = patches.Rectangle(
(x_min, y_min), width, height,
linewidth=2, edgecolor='red', facecolor='none' # 关键facecolor='none'
)
ax.add_patch(rect)
# 在矩形中心标注目标编号
center_x = x_min+width/2
center_y = y_min+height/2
ax.text(center_x, center_y, str(did),
ha='center', va='center', fontsize=6, color='red')
plt.tight_layout()
plt.show()
print("绘图结束")
return None
def slictDataAndOutlabel(clusterDict,dotalist,tiff_data,tiff_basename,outfolderpath,im_geotrans, im_proj):
"""
切分标签输出结果与文件
:param clusterDict:
:param dotalist:
:param tiff_data:
:param tiff_name:
:param outfolderpath:
:return:
"""
for cid in clusterDict:
sx, sy = clusterDict[cid]["p"]
sliceData = tiff_data[sy:(sy + 1024), sx:(sx + 1024)]
outbinname="{}_{}.bin".format(tiff_basename,cid)
outlabelname="{}_{}.txt".format(tiff_basename,cid)
# 获取样本列表
outdotalist=[]
for did in clusterDict[cid]["id"] :
tempdota=dotalist[did]
tempdota.x1=tempdota.x1-sx
tempdota.x2=tempdota.x2-sx
tempdota.x3=tempdota.x3-sx
tempdota.x4=tempdota.x4-sx
tempdota.y1=tempdota.y1-sy
tempdota.y2=tempdota.y2-sy
tempdota.y3=tempdota.y3-sy
tempdota.y4=tempdota.y4-sy
outdotalist.append(tempdota)
outlabelpath=os.path.join(outfolderpath,outlabelname)
outbinpath=os.path.join(outfolderpath,outbinname)
temp_im_geotrans=[tempi for tempi in im_geotrans]
# 处理 0,3
temp_im_geotrans[0]=im_geotrans[0]+sx*im_geotrans[1]+im_geotrans[2]*sy # x
temp_im_geotrans[3]=im_geotrans[3]+sx*im_geotrans[4]+im_geotrans[5]*sy # y
write_envi(sliceData,temp_im_geotrans,im_proj,outbinpath)
writerDotaFile(outdotalist,outlabelpath)
##########################################################################
# 处理流程图
##########################################################################
def DataSampleSliceRasterProcess(inbinfile,labelfilepath,outfolderpath,methodstr):
tiff_proj, tiff_trans, tiff_data = read_tif(inbinfile)
tiff_inv_trans = gdal.InvGeoTransform(tiff_trans)
dotalist=label2pixelpoints(labelfilepath,tiff_inv_trans,methodstr)
imgheight, imgwidth=tiff_data.shape
clusterDict=getclusterDict(dotalist,imgheight,imgwidth)
drawSliceRasterPrivew(tiff_data, dotalist, clusterDict)
ncols=int(len(clusterDict)/9+1)
drawSlictplot(clusterDict, dotalist, tiff_data, 9, ncols)
tiff_name=os.path.basename(inbinfile)
tiff_basename=os.path.splitext(tiff_name)[0]
slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath, tiff_trans, tiff_proj)
pass
def getParams():
parser = argparse.ArgumentParser()
parser.add_argument('-i','--inbinfile',type=str,default=r'F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.bin', help='输入tiff的bin文件')
parser.add_argument('-l', '--labelfilepath',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\标注\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01_LC.txt", help='输入标注')
parser.add_argument('-o', '--outfolderpath',type=str,default=r'F:\天仪SAR卫星数据集\舰船数据\切分3', help='切片文件夹地址')
group = parser.add_mutually_exclusive_group()
group.add_argument(
'--geolabel',
action='store_const',
const='geolabel',
dest='method',
help='标注坐标点为地理坐标'
)
group.add_argument(
'--pixellabel',
action='store_const',
const='pixellabel',
dest='method',
help='标注坐标系统为输入影像的像空间坐标'
)
parser.set_defaults(method='geolabel')
args = parser.parse_args()
return args
if __name__ == '__main__':
parser = getParams()
inbinfile=parser.inbinfile
labelfilepath=parser.labelfilepath
outfolderpath=parser.outfolderpath
methodstr= parser.method
print('inbinfile=',inbinfile)
print('labelfilepath=',labelfilepath)
print('outfolderpath=',outfolderpath)
print('methodstr=',methodstr)
DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath,methodstr)
print("样本切分完成")

View File

@ -0,0 +1,80 @@
import os
import argparse
"""
Dota数据集标注
每条标注数据应包含
+ 8 个数值表示目标的四个角点的坐标x1, y1, x2, y2, x3, y3, x4, y4按顺时针或逆时针顺序排列无需进行归一化处理
+倒数第二列为标注的目标类别名称jun_ship等
+最后一列为识别目标的难易程度difficulty
提供数据集标注
"""
class DotaObj(object):
def __init__(self, x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty):
self.x1=x1
self.y1=y1
self.x2=x2
self.y2=y2
self.x3=x3
self.y3=y3
self.x4=x4
self.y4=y4
self.clsname=clsname
self.difficulty=difficulty
def __str__(self):
return "{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}".format(
self.x1,self.y1,
self.x2,self.y2,
self.x3,self.y3,
self.x4,self.y4,
self.clsname,self.difficulty
)
def createDota(x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty):
return DotaObj(x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty) # 8+2
def readDotaFile(dotafilepath):
content=None
with open(dotafilepath,'r',encoding="utf-8") as fp:
content=fp.read()
contentlines=content.split("\n")
# 逐行分解
result=[]
for linestr in contentlines:
linestr=linestr.replace("\t"," ")
linestr=linestr.replace(" "," ")
linemetas=linestr.split(" ")
if(len(linemetas)>=10):
x1=float(linemetas[0])
y1=float(linemetas[1])
x2=float(linemetas[2])
y2=float(linemetas[3])
x3=float(linemetas[4])
y3=float(linemetas[5])
x4=float(linemetas[6])
y4=float(linemetas[7])
clsname=linemetas[8]
difficulty=linemetas[9]
result.append(createDota(x1, y1, x2, y2, x3, y3, x4, y4, clsname,difficulty))
else:
print("parse result: ", linestr)
return result
def writerDotaFile(dotalist,filepath):
with open(filepath,'a',encoding="utf-8") as fp:
for dota in dotalist:
if isinstance(dota,DotaObj):
fp.write("{}\n".format(str(dota)))
else:
fp.write("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}\n".format(
dota[0],dota[1],dota[2],dota[3],dota[4],dota[5],dota[6],dota[7],dota[8],dota[9]
))

View File

@ -0,0 +1,39 @@
from osgeo import ogr, gdal
import os
import argparse
def shp_to_geojson(shp_path, geojson_path):
# 设置编码选项
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
# 打开Shapefile
src_ds = ogr.Open(shp_path)
src_layer = src_ds.GetLayer(0)
# 创建输出GeoJSON
driver = ogr.GetDriverByName('GeoJSON')
if os.path.exists(geojson_path):
driver.DeleteDataSource(geojson_path)
dst_ds = driver.CreateDataSource(geojson_path)
dst_layer = dst_ds.CreateLayer('output', src_layer.GetSpatialRef())
# 复制字段定义
dst_layer.CreateFields(src_layer.schema)
# 复制要素
for feature in src_layer:
dst_feature = ogr.Feature(dst_layer.GetLayerDefn())
dst_feature.SetGeometry(feature.GetGeometryRef())
for j in range(feature.GetFieldCount()):
dst_feature.SetField(j, feature.GetField(j))
dst_layer.CreateFeature(dst_feature)
# 清理
dst_ds = None
src_ds = None
print("转换完成")
# 使用示例
shp_to_geojson('input.shp', 'output.geojson')

View File

@ -0,0 +1,220 @@
from osgeo import ogr, gdal
import os
import argparse
import numpy as np
from PIL import Image
def get_filename_without_ext(path):
base_name = os.path.basename(path)
if '.' not in base_name or base_name.startswith('.'):
return base_name
return base_name.rsplit('.', 1)[0]
def read_tif(path):
dataset = gdal.Open(path) # 打开TIF文件
if dataset is None:
print("无法打开文件")
return None, None, None
cols = dataset.RasterXSize # 图像宽度
rows = dataset.RasterYSize # 图像高度
bands = dataset.RasterCount
im_proj = dataset.GetProjection() # 获取投影信息
im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息
im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组
print("行数:", rows)
print("列数:", cols)
print("波段:", bands)
del dataset # 关闭数据集
return im_proj, im_Geotrans, im_data
def write_envi(im_data, im_geotrans, im_proj, output_path):
"""
将数组数据写入ENVI格式文件
:param im_data: 输入的numpy数组2D或3D
:param im_geotrans: 仿射变换参数6元组
:param im_proj: 投影信息WKT字符串
:param output_path: 输出文件路径无需扩展名会自动生成.dat和.hdr
"""
im_bands = 1
im_height, im_width = im_data.shape
# 创建ENVI格式驱动
driver = gdal.GetDriverByName("ENVI")
dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte)
if dataset is not None:
dataset.SetGeoTransform(im_geotrans) # 设置地理变换参数
dataset.SetProjection(im_proj) # 设置投影
dataset.GetRasterBand(1).WriteArray(im_data)
dataset.FlushCache() # 确保数据写入磁盘
dataset = None # 关闭文件
def Strech_linear(im_data):
im_data_dB=10*np.log10(im_data)
immask=np.isfinite(im_data_dB)
infmask = np.isinf(im_data_dB)
imvail_data=im_data[immask]
im_data_dB=0
minvalue=np.nanmin(imvail_data)
maxvalue=np.nanmax(imvail_data)
infmask = np.isinf(im_data_dB)
im_data[infmask] = minvalue-100
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254+1
im_data=np.clip(im_data,0,255)
return im_data.astype(np.uint8)
def Strech_linear1(im_data):
im_data_dB = 10 * np.log10(im_data)
immask = np.isfinite(im_data_dB)
infmask = np.isinf(im_data_dB)
imvail_data = im_data[immask]
im_data_dB=0
minvalue=np.percentile(imvail_data,1)
maxvalue = np.percentile(imvail_data, 99)
im_data[infmask] = minvalue - 100
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
im_data = np.clip(im_data, 0, 255)
return im_data.astype(np.uint8)
def Strech_linear2(im_data):
im_data_dB = 10 * np.log10(im_data)
immask = np.isfinite(im_data_dB)
infmask = np.isinf(im_data_dB)
imvail_data = im_data[immask]
im_data_dB = 0
minvalue = np.percentile(imvail_data, 2)
maxvalue = np.percentile(imvail_data, 98)
im_data[infmask] = minvalue - 100
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
im_data = np.clip(im_data, 0, 255)
return im_data.astype(np.uint8)
def Strech_linear5(im_data):
im_data_dB = 10 * np.log10(im_data)
immask = np.isfinite(im_data_dB)
infmask = np.isinf(im_data_dB)
imvail_data = im_data[immask]
im_data_dB = 0
minvalue = np.percentile(imvail_data, 5)
maxvalue = np.percentile(imvail_data, 95)
im_data[infmask] = minvalue - 100
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
im_data = np.clip(im_data, 0, 255)
return im_data.astype(np.uint8)
def Strech_SquareRoot(im_data):
im_data=np.sqrt(im_data)
immask = np.isfinite(im_data)
imvail_data = im_data[immask]
minvalue=np.nanmin(imvail_data)
maxvalue=np.nanmax(imvail_data)
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
im_data = np.clip(im_data, 0, 255)
return im_data.astype(np.uint8)
def DataStrech(im_data,strechmethod):
# [,"Linear1","Linear2","Linear5","SquareRoot"]
if strechmethod == "Linear" :
return Strech_linear(im_data)
elif strechmethod == "Linear1":
return Strech_linear1(im_data)
elif strechmethod == "Linear2":
return Strech_linear2(im_data)
elif strechmethod == "Linear5":
return Strech_linear5(im_data)
elif strechmethod == "SquareRoot":
return Strech_SquareRoot(im_data)
else:
return im_data.astype(np.uint8)
def stretchProcess(infilepath,outfilepath,strechmethod):
im_proj, im_Geotrans, im_data=read_tif(infilepath)
envifilepath=get_filename_without_ext(outfilepath)+".bin"
envifilepath=os.path.join(os.path.dirname(outfilepath),envifilepath)
im_data = DataStrech(im_data,strechmethod)
im_data = im_data.astype(np.uint8)
write_envi(im_data,im_Geotrans,im_proj,envifilepath)
Image.fromarray(im_data).save(outfilepath,compress_level=0)
print("图像拉伸处理结束")
def getParams():
parser = argparse.ArgumentParser()
parser.add_argument('-i','--infile',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.tiff", help='输入shapefile文件')
parser.add_argument('-o', '--outfile',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.png", help='输出geojson文件')
group = parser.add_mutually_exclusive_group()
group.add_argument(
'--Linear',
action='store_const',
const='Linear',
dest='method',
help='线性拉伸'
)
group.add_argument(
'--Linear1prec',
action='store_const',
const='Linear1',
dest='method',
help='1%线性拉伸'
)
group.add_argument(
'--Linear2prec',
action='store_const',
const='Linear2',
dest='method',
help='2%线性拉伸'
)
group.add_argument(
'--Linear5prec',
action='store_const',
const='Linear5',
dest='method',
help='5%线性拉伸'
)
group.add_argument(
'--SquareRoot',
action='store_const',
const='SquareRoot',
dest='method',
help='平方根拉伸'
)
parser.set_defaults(method='SquareRoot')
args = parser.parse_args()
return args
if __name__ == '__main__':
parser = getParams()
intiffPath=parser.infile
outbinPath=parser.outfile
methodstr=parser.method
print('infile=',intiffPath)
print('outfile=',outbinPath)
print('method=',methodstr)
stretchProcess(intiffPath, outbinPath, methodstr)

View File

@ -0,0 +1,89 @@
from osgeo import ogr
import os
import argparse
def shapefile_to_dota(shp_path, output_path, class_field='class', difficulty_value=1):
"""
将Shapefile转换为DOTA格式
:param shp_path: Shapefile文件路径
:param output_path: 输出目录
:param class_field: 类别字段名
:param difficulty_value: 难度默认字段
"""
# 注册所有驱动
ogr.RegisterAll()
# 打开Shapefile文件
driver = ogr.GetDriverByName('ESRI Shapefile')
datasource = driver.Open(shp_path, 0)
if datasource is None:
print("无法打开Shapefile文件")
return
# 获取图层
layer = datasource.GetLayer()
output_file = output_path
with open(output_file, 'w',encoding="utf-8") as f:
# 写入DOTA格式头信息(可选)
# f.write('imagesource:unknown\n')
# f.write('gsd:1.0\n')
# 遍历所有要素
for feature in layer:
# 获取几何对象
geom = feature.GetGeometryRef()
# 获取类别和难度
class_name = feature.GetField(class_field) if feature.GetField(class_field) else 'unknown'
difficulty = difficulty_value
# 处理不同类型的几何图形
if geom.GetGeometryName() == 'POLYGON':
# 获取多边形外环
ring = geom.GetGeometryRef(0)
# 获取所有点
points = []
for i in range(ring.GetPointCount()):
points.append(ring.GetPoint(i))
# 确保有足够的点(至少4个)
if len(points) >= 4:
# 取前4个点作为DOTA格式的四个角点
# 注意: DOTA要求按顺序排列(顺时针或逆时针)
x1, y1 = points[0][0], points[0][1]
x2, y2 = points[1][0], points[1][1]
x3, y3 = points[2][0], points[2][1]
x4, y4 = points[3][0], points[3][1]
# 写入DOTA格式行
line = f"{x1} {y1} {x2} {y2} {x3} {y3} {x4} {y4} {class_name} {difficulty}\n"
f.write(line)
# 释放资源
datasource.Destroy()
print("转换完毕")
def getParams():
parser = argparse.ArgumentParser()
parser.add_argument('-i','--infile',type=str,default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.shp', help='输入shapefile文件')
parser.add_argument('-o', '--outfile',type=str,default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.txt', help='输出geojson文件')
parser.add_argument('-c', '--clsname',type=str,default=r'Label', help='输出geojson文件')
parser.add_argument('-d', '--difficulty',type=int,default=1, help='输出geojson文件')
args = parser.parse_args()
return args
if __name__ == '__main__':
parser = getParams()
inFilePath=parser.infile
outpath=parser.outfile
clsname=parser.clsname
difficulty=parser.difficulty
print('infile=',inFilePath)
print('outfile=',outpath)
print('clsname=',clsname)
print('difficulty=',difficulty)
shapefile_to_dota(inFilePath, outpath, clsname, difficulty)
# 使用示例

View File

@ -0,0 +1,51 @@
from osgeo import ogr, gdal
import os
import argparse
def shp_to_geojson(shp_path, geojson_path):
# 设置编码选项
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
# 打开Shapefile
src_ds = ogr.Open(shp_path)
src_layer = src_ds.GetLayer(0)
# 创建输出GeoJSON
driver = ogr.GetDriverByName('GeoJSON')
if os.path.exists(geojson_path):
driver.DeleteDataSource(geojson_path)
dst_ds = driver.CreateDataSource(geojson_path)
dst_layer = dst_ds.CreateLayer('output', src_layer.GetSpatialRef())
# 复制字段定义
dst_layer.CreateFields(src_layer.schema)
# 复制要素
for feature in src_layer:
dst_feature = ogr.Feature(dst_layer.GetLayerDefn())
dst_feature.SetGeometry(feature.GetGeometryRef())
for j in range(feature.GetFieldCount()):
dst_feature.SetField(j, feature.GetField(j))
dst_layer.CreateFeature(dst_feature)
# 清理
dst_ds = None
src_ds = None
print("转换完成")
def getParams():
parser = argparse.ArgumentParser()
parser.add_argument('-i','--infile',default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.shp', help='输入shapefile文件')
parser.add_argument('-o', '--outfile',default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.json', help='输出geojson文件')
args = parser.parse_args()
return args
if __name__ == '__main__':
parser = getParams()
inFilePath=parser.infile
outpath=parser.outfile
print('infile=',inFilePath)
print('outfile=',outpath)
shp_to_geojson(inFilePath, outpath)