样本切分代码完成了
parent
cfed1d4638
commit
d4b5539ce6
|
|
@ -713,3 +713,7 @@ FodyWeavers.xsd
|
||||||
hs_err_pid*
|
hs_err_pid*
|
||||||
replay_pid*
|
replay_pid*
|
||||||
|
|
||||||
|
# 其他数据处理程序
|
||||||
|
ProgramEXE/
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,939 @@
|
||||||
|
"""
|
||||||
|
@Project :microproduct
|
||||||
|
@File :ImageHandle.py
|
||||||
|
@Function :实现对待处理SAR数据的读取、格式标准化和处理完后保存文件功能
|
||||||
|
@Author :LMM
|
||||||
|
@Date :2021/10/19 14:39
|
||||||
|
@Version :1.0.0
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import time
|
||||||
|
from osgeo import gdal
|
||||||
|
from osgeo import osr
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
from xml.etree.ElementTree import ElementTree, Element
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ImageHandler:
|
||||||
|
"""
|
||||||
|
影像读取、编辑、保存
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
@staticmethod
|
||||||
|
def get_dataset(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 图像句柄
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def get_scope(self, filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 图像范围
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
im_scope = self.cal_img_scope(dataset)
|
||||||
|
del dataset
|
||||||
|
return im_scope
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_projection(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 地图投影信息
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
im_proj = dataset.GetProjection()
|
||||||
|
del dataset
|
||||||
|
return im_proj
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_geotransform(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 从图像坐标空间(行、列),也称为(像素、线)到地理参考坐标空间(投影或地理坐标)的仿射变换
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
geotransform = dataset.GetGeoTransform()
|
||||||
|
del dataset
|
||||||
|
return geotransform
|
||||||
|
|
||||||
|
def get_invgeotransform(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 从地理参考坐标空间(投影或地理坐标)的到图像坐标空间(行、列
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
geotransform = dataset.GetGeoTransform()
|
||||||
|
geotransform=gdal.InvGeoTransform(geotransform)
|
||||||
|
del dataset
|
||||||
|
return geotransform
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_bands(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 影像的波段数
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
bands = dataset.RasterCount
|
||||||
|
del dataset
|
||||||
|
return bands
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def geo2lonlat(dataset, x, y):
|
||||||
|
"""
|
||||||
|
将投影坐标转为经纬度坐标(具体的投影坐标系由给定数据确定)
|
||||||
|
:param dataset: GDAL地理数据
|
||||||
|
:param x: 投影坐标x
|
||||||
|
:param y: 投影坐标y
|
||||||
|
:return: 投影坐标(x, y)对应的经纬度坐标(lon, lat)
|
||||||
|
"""
|
||||||
|
prosrs = osr.SpatialReference()
|
||||||
|
prosrs.ImportFromWkt(dataset.GetProjection())
|
||||||
|
geosrs = prosrs.CloneGeogCS()
|
||||||
|
ct = osr.CoordinateTransformation(prosrs, geosrs)
|
||||||
|
coords = ct.TransformPoint(x, y)
|
||||||
|
return coords[:2]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_band_array(filename, num=1):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:param num: 波段序号
|
||||||
|
:return: 对应波段的矩阵数据
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
bands = dataset.GetRasterBand(num)
|
||||||
|
array = bands.ReadAsArray(0, 0, bands.XSize, bands.YSize)
|
||||||
|
|
||||||
|
# if 'int' in str(array.dtype):
|
||||||
|
# array[np.where(array == -9999)] = np.inf
|
||||||
|
# else:
|
||||||
|
# array[np.where(array < -9000.0)] = np.nan
|
||||||
|
|
||||||
|
del dataset
|
||||||
|
return array
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_imgArray(filename, im_data, no_data='null'):
|
||||||
|
"""
|
||||||
|
影像保存
|
||||||
|
:param filename: 保存的路径
|
||||||
|
:param im_proj:
|
||||||
|
:param im_geotrans:
|
||||||
|
:param im_data:
|
||||||
|
:param no_data: 把无效值设置为 nodata
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
gdal_dtypes = {
|
||||||
|
'int8': gdal.GDT_Byte,
|
||||||
|
'unit16': gdal.GDT_UInt16,
|
||||||
|
'int16': gdal.GDT_Int16,
|
||||||
|
'unit32': gdal.GDT_UInt32,
|
||||||
|
'int32': gdal.GDT_Int32,
|
||||||
|
'float32': gdal.GDT_Float32,
|
||||||
|
'float64': gdal.GDT_Float64,
|
||||||
|
}
|
||||||
|
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
|
||||||
|
datatype = gdal_dtypes[im_data.dtype.name]
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
# 判读数组维数
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
|
||||||
|
# 创建文件
|
||||||
|
if os.path.exists(os.path.split(filename)[0]) is False:
|
||||||
|
os.makedirs(os.path.split(filename)[0])
|
||||||
|
|
||||||
|
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
|
||||||
|
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
|
||||||
|
|
||||||
|
if im_bands == 1:
|
||||||
|
# outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据
|
||||||
|
outband = dataset.GetRasterBand(1)
|
||||||
|
outband.WriteArray(im_data)
|
||||||
|
if no_data != 'null':
|
||||||
|
outband.SetNoDataValue(no_data)
|
||||||
|
outband.FlushCache()
|
||||||
|
else:
|
||||||
|
for i in range(im_bands):
|
||||||
|
outband = dataset.GetRasterBand(1 + i)
|
||||||
|
outband.WriteArray(im_data[i])
|
||||||
|
outband.FlushCache()
|
||||||
|
# outRaster.GetRasterBand(i + 1).WriteArray(array[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_data(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 获取所有波段的数据
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
im_width = dataset.RasterXSize
|
||||||
|
im_height = dataset.RasterYSize
|
||||||
|
im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
|
||||||
|
del dataset
|
||||||
|
return im_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_scopes(ori_sim):
|
||||||
|
ori_sim_data = ImageHandler.get_data(ori_sim)
|
||||||
|
lon = ori_sim_data[0, :, :]
|
||||||
|
lat = ori_sim_data[1, :, :]
|
||||||
|
|
||||||
|
min_lon = np.nanmin(lon)
|
||||||
|
max_lon = np.nanmax(lon)
|
||||||
|
min_lat = np.nanmin(lat)
|
||||||
|
max_lat = np.nanmax(lat)
|
||||||
|
|
||||||
|
scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]
|
||||||
|
return scopes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dataset(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 获取所有波段的数据
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_band_array(filename):
|
||||||
|
"""
|
||||||
|
(大气延迟算法)
|
||||||
|
将ERA-5影像所有波段存为一个数组, 波段数在第三维度 get_data()->(37,8,8)
|
||||||
|
:param filename: 影像路径 get_all_band_array ->(8,8,37)
|
||||||
|
:return: 影像数组
|
||||||
|
"""
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
x_size = dataset.RasterXSize
|
||||||
|
y_size = dataset.RasterYSize
|
||||||
|
nums = dataset.RasterCount
|
||||||
|
array = np.zeros((y_size, x_size, nums), dtype=float)
|
||||||
|
if nums == 1:
|
||||||
|
bands_0 = dataset.GetRasterBand(1)
|
||||||
|
array = bands_0.ReadAsArray(0, 0, x_size, y_size)
|
||||||
|
else:
|
||||||
|
for i in range(0, nums):
|
||||||
|
bands = dataset.GetRasterBand(i+1)
|
||||||
|
arr = bands.ReadAsArray(0, 0, x_size, y_size)
|
||||||
|
array[:, :, i] = arr
|
||||||
|
return array
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_img_width(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 影像宽度
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
width = dataset.RasterXSize
|
||||||
|
|
||||||
|
del dataset
|
||||||
|
return width
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_img_height(filename):
|
||||||
|
"""
|
||||||
|
:param filename: tif路径
|
||||||
|
:return: 影像高度
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
height = dataset.RasterYSize
|
||||||
|
del dataset
|
||||||
|
return height
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read_img(filename):
|
||||||
|
"""
|
||||||
|
影像读取
|
||||||
|
:param filename:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
img_dataset = gdal.Open(filename) # 打开文件
|
||||||
|
|
||||||
|
if img_dataset is None:
|
||||||
|
msg = 'Could not open ' + filename
|
||||||
|
print(msg)
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
im_proj = img_dataset.GetProjection() # 地图投影信息
|
||||||
|
if im_proj is None:
|
||||||
|
return None, None, None
|
||||||
|
im_geotrans = img_dataset.GetGeoTransform() # 仿射矩阵
|
||||||
|
|
||||||
|
im_width = img_dataset.RasterXSize # 栅格矩阵的行数
|
||||||
|
im_height = img_dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height)
|
||||||
|
del img_dataset
|
||||||
|
return im_proj, im_geotrans, im_arr
|
||||||
|
|
||||||
|
def cal_img_scope(self, dataset):
|
||||||
|
"""
|
||||||
|
计算影像的地理坐标范围
|
||||||
|
根据GDAL的六参数模型将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换)
|
||||||
|
:param dataset :GDAL地理数据
|
||||||
|
:return: list[point_upleft, point_upright, point_downleft, point_downright]
|
||||||
|
"""
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
img_geotrans = dataset.GetGeoTransform()
|
||||||
|
if img_geotrans is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = dataset.RasterXSize # 栅格矩阵的列数
|
||||||
|
height = dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
|
||||||
|
point_upleft = self.trans_rowcol2geo(img_geotrans, 0, 0)
|
||||||
|
point_upright = self.trans_rowcol2geo(img_geotrans, width, 0)
|
||||||
|
point_downleft = self.trans_rowcol2geo(img_geotrans, 0, height)
|
||||||
|
point_downright = self.trans_rowcol2geo(img_geotrans, width, height)
|
||||||
|
|
||||||
|
return [point_upleft, point_upright, point_downleft, point_downright]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_scope_ori_sim(filename):
|
||||||
|
"""
|
||||||
|
计算影像的地理坐标范围
|
||||||
|
根据GDAL的六参数模型将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换)
|
||||||
|
:param dataset :GDAL地理数据
|
||||||
|
:return: list[point_upleft, point_upright, point_downleft, point_downright]
|
||||||
|
"""
|
||||||
|
gdal.AllRegister()
|
||||||
|
dataset = gdal.Open(filename)
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = dataset.RasterXSize # 栅格矩阵的列数
|
||||||
|
height = dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
|
||||||
|
band1 = dataset.GetRasterBand(1)
|
||||||
|
array1 = band1.ReadAsArray(0, 0, band1.XSize, band1.YSize)
|
||||||
|
|
||||||
|
band2 = dataset.GetRasterBand(2)
|
||||||
|
array2 = band2.ReadAsArray(0, 0, band2.XSize, band2.YSize)
|
||||||
|
|
||||||
|
if array1[0, 0] < array1[0, width-1]:
|
||||||
|
point_upleft = [array1[0, 0], array2[0, 0]]
|
||||||
|
point_upright = [array1[0, width-1], array2[0, width-1]]
|
||||||
|
else:
|
||||||
|
point_upright = [array1[0, 0], array2[0, 0]]
|
||||||
|
point_upleft = [array1[0, width-1], array2[0, width-1]]
|
||||||
|
|
||||||
|
|
||||||
|
if array1[height-1, 0] < array1[height-1, width-1]:
|
||||||
|
point_downleft = [array1[height - 1, 0], array2[height - 1, 0]]
|
||||||
|
point_downright = [array1[height - 1, width - 1], array2[height - 1, width - 1]]
|
||||||
|
else:
|
||||||
|
point_downright = [array1[height - 1, 0], array2[height - 1, 0]]
|
||||||
|
point_downleft = [array1[height - 1, width - 1], array2[height - 1, width - 1]]
|
||||||
|
|
||||||
|
|
||||||
|
if(array2[0, 0] < array2[height - 1, 0]):
|
||||||
|
#上下调换顺序
|
||||||
|
tmp1 = point_upleft
|
||||||
|
point_upleft = point_downleft
|
||||||
|
point_downleft = tmp1
|
||||||
|
|
||||||
|
tmp2 = point_upright
|
||||||
|
point_upright = point_downright
|
||||||
|
point_downright = tmp2
|
||||||
|
|
||||||
|
return [point_upleft, point_upright, point_downleft, point_downright]
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trans_rowcol2geo(img_geotrans,img_col, img_row):
|
||||||
|
"""
|
||||||
|
据GDAL的六参数模型仿射矩阵将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换)
|
||||||
|
:param img_geotrans: 仿射矩阵
|
||||||
|
:param img_col:图像纵坐标
|
||||||
|
:param img_row:图像横坐标
|
||||||
|
:return: [geo_x,geo_y]
|
||||||
|
"""
|
||||||
|
geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row
|
||||||
|
geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row
|
||||||
|
return [geo_x, geo_y]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_era_into_img(filename, im_proj, im_geotrans, im_data):
|
||||||
|
"""
|
||||||
|
影像保存
|
||||||
|
:param filename:
|
||||||
|
:param im_proj:
|
||||||
|
:param im_geotrans:
|
||||||
|
:param im_data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
gdal_dtypes = {
|
||||||
|
'int8': gdal.GDT_Byte,
|
||||||
|
'unit16': gdal.GDT_UInt16,
|
||||||
|
'int16': gdal.GDT_Int16,
|
||||||
|
'unit32': gdal.GDT_UInt32,
|
||||||
|
'int32': gdal.GDT_Int32,
|
||||||
|
'float32': gdal.GDT_Float32,
|
||||||
|
'float64': gdal.GDT_Float64,
|
||||||
|
}
|
||||||
|
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
|
||||||
|
datatype = gdal_dtypes[im_data.dtype.name]
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
# 判读数组维数
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_height, im_width, im_bands = im_data.shape # shape[0] 行数
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
|
||||||
|
# 创建文件
|
||||||
|
if os.path.exists(os.path.split(filename)[0]) is False:
|
||||||
|
os.makedirs(os.path.split(filename)[0])
|
||||||
|
|
||||||
|
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
|
||||||
|
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
|
||||||
|
if im_bands == 1:
|
||||||
|
dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据
|
||||||
|
else:
|
||||||
|
for i in range(im_bands):
|
||||||
|
dataset.GetRasterBand(i + 1).WriteArray(im_data[:, :, i])
|
||||||
|
# dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
# 写GeoTiff文件
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def lat_lon_to_pixel(raster_dataset_path, location):
|
||||||
|
"""From zacharybears.com/using-python-to-translate-latlon-locations-to-pixels-on-a-geotiff/."""
|
||||||
|
gdal.AllRegister()
|
||||||
|
raster_dataset = gdal.Open(raster_dataset_path)
|
||||||
|
if raster_dataset is None:
|
||||||
|
return None
|
||||||
|
ds = raster_dataset
|
||||||
|
gt = ds.GetGeoTransform()
|
||||||
|
srs = osr.SpatialReference()
|
||||||
|
srs.ImportFromWkt(ds.GetProjection())
|
||||||
|
srs_lat_lon = srs.CloneGeogCS()
|
||||||
|
ct = osr.CoordinateTransformation(srs_lat_lon, srs)
|
||||||
|
new_location = [None, None]
|
||||||
|
# Change the point locations into the GeoTransform space
|
||||||
|
(new_location[1], new_location[0], holder) = ct.TransformPoint(location[1], location[0])
|
||||||
|
# Translate the x and y coordinates into pixel values
|
||||||
|
Xp = new_location[0]
|
||||||
|
Yp = new_location[1]
|
||||||
|
dGeoTrans = gt
|
||||||
|
dTemp = dGeoTrans[1] * dGeoTrans[5] - dGeoTrans[2] * dGeoTrans[4]
|
||||||
|
Xpixel = (dGeoTrans[5] * (Xp - dGeoTrans[0]) - dGeoTrans[2] * (Yp - dGeoTrans[3])) / dTemp
|
||||||
|
Yline = (dGeoTrans[1] * (Yp - dGeoTrans[3]) - dGeoTrans[4] * (Xp - dGeoTrans[0])) / dTemp
|
||||||
|
del raster_dataset
|
||||||
|
return (Xpixel, Yline)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_img(filename, im_proj, im_geotrans, im_data, no_data='0'):
|
||||||
|
"""
|
||||||
|
影像保存
|
||||||
|
:param filename: 保存的路径
|
||||||
|
:param im_proj:
|
||||||
|
:param im_geotrans:
|
||||||
|
:param im_data:
|
||||||
|
:param no_data: 把无效值设置为 nodata
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
gdal_dtypes = {
|
||||||
|
'int8': gdal.GDT_Byte,
|
||||||
|
'unit16': gdal.GDT_UInt16,
|
||||||
|
'int16': gdal.GDT_Int16,
|
||||||
|
'unit32': gdal.GDT_UInt32,
|
||||||
|
'int32': gdal.GDT_Int32,
|
||||||
|
'float32': gdal.GDT_Float32,
|
||||||
|
'float64': gdal.GDT_Float64,
|
||||||
|
}
|
||||||
|
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
|
||||||
|
datatype = gdal_dtypes[im_data.dtype.name]
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
flag = False
|
||||||
|
# 判读数组维数
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
flag = True
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
|
||||||
|
# 创建文件
|
||||||
|
if os.path.exists(os.path.split(filename)[0]) is False:
|
||||||
|
os.makedirs(os.path.split(filename)[0])
|
||||||
|
|
||||||
|
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
|
||||||
|
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
|
||||||
|
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
|
||||||
|
if im_bands == 1:
|
||||||
|
# outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据
|
||||||
|
if flag:
|
||||||
|
outband = dataset.GetRasterBand(1)
|
||||||
|
outband.WriteArray(im_data[0])
|
||||||
|
if no_data != 'null':
|
||||||
|
outband.SetNoDataValue(np.double(no_data))
|
||||||
|
outband.FlushCache()
|
||||||
|
else:
|
||||||
|
outband = dataset.GetRasterBand(1)
|
||||||
|
outband.WriteArray(im_data)
|
||||||
|
if no_data != 'null':
|
||||||
|
outband.SetNoDataValue(np.double(no_data))
|
||||||
|
outband.FlushCache()
|
||||||
|
else:
|
||||||
|
for i in range(im_bands):
|
||||||
|
outband = dataset.GetRasterBand(1 + i)
|
||||||
|
outband.WriteArray(im_data[i])
|
||||||
|
if no_data != 'null':
|
||||||
|
outband.SetNoDataValue(np.double(no_data))
|
||||||
|
outband.FlushCache()
|
||||||
|
# outRaster.GetRasterBand(i + 1).WriteArray(array[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
# 写GeoTiff文件
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_img_envi(filename, im_proj, im_geotrans, im_data, no_data='null'):
|
||||||
|
"""
|
||||||
|
影像保存
|
||||||
|
:param filename: 保存的路径
|
||||||
|
:param im_proj:
|
||||||
|
:param im_geotrans:
|
||||||
|
:param im_data:
|
||||||
|
:param no_data: 把无效值设置为 nodata
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
gdal_dtypes = {
|
||||||
|
'int8': gdal.GDT_Byte,
|
||||||
|
'unit16': gdal.GDT_UInt16,
|
||||||
|
'int16': gdal.GDT_Int16,
|
||||||
|
'unit32': gdal.GDT_UInt32,
|
||||||
|
'int32': gdal.GDT_Int32,
|
||||||
|
'float32': gdal.GDT_Float32,
|
||||||
|
'float64': gdal.GDT_Float64,
|
||||||
|
}
|
||||||
|
if not gdal_dtypes.get(im_data.dtype.name, None) is None:
|
||||||
|
datatype = gdal_dtypes[im_data.dtype.name]
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
# 判读数组维数
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
|
||||||
|
# 创建文件
|
||||||
|
if os.path.exists(os.path.split(filename)[0]) is False:
|
||||||
|
os.makedirs(os.path.split(filename)[0])
|
||||||
|
|
||||||
|
driver = gdal.GetDriverByName("ENVI") # 数据类型必须有,因为要计算需要多大内存空间
|
||||||
|
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
|
||||||
|
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
|
||||||
|
if im_bands == 1:
|
||||||
|
# outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据
|
||||||
|
outband = dataset.GetRasterBand(1)
|
||||||
|
outband.WriteArray(im_data)
|
||||||
|
if no_data != 'null':
|
||||||
|
outband.SetNoDataValue(no_data)
|
||||||
|
outband.FlushCache()
|
||||||
|
else:
|
||||||
|
for i in range(im_bands):
|
||||||
|
outband = dataset.GetRasterBand(1 + i)
|
||||||
|
outband.WriteArray(im_data[i])
|
||||||
|
outband.FlushCache()
|
||||||
|
# outRaster.GetRasterBand(i + 1).WriteArray(array[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
|
||||||
|
"""
|
||||||
|
图像中写入rpc信息
|
||||||
|
"""
|
||||||
|
# 判断栅格数据的数据类型
|
||||||
|
if 'int8' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_Byte
|
||||||
|
elif 'int16' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_Int16
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
# 判读数组维数
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
|
||||||
|
# 创建文件
|
||||||
|
driver = gdal.GetDriverByName("GTiff")
|
||||||
|
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
|
||||||
|
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
|
||||||
|
# 写入RPC参数
|
||||||
|
for k in rpc_dict.keys():
|
||||||
|
dataset.SetMetadataItem(k, rpc_dict[k], 'RPC')
|
||||||
|
|
||||||
|
if im_bands == 1:
|
||||||
|
dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据
|
||||||
|
else:
|
||||||
|
for i in range(im_bands):
|
||||||
|
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
|
||||||
|
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
|
||||||
|
def transtif2mask(self,out_tif_path, in_tif_path, threshold):
|
||||||
|
"""
|
||||||
|
:param out_tif_path:输出路径
|
||||||
|
:param in_tif_path:输入的路径
|
||||||
|
:param threshold:阈值
|
||||||
|
"""
|
||||||
|
im_proj, im_geotrans, im_arr, im_scope = self.read_img(in_tif_path)
|
||||||
|
im_arr_mask = (im_arr < threshold).astype(int)
|
||||||
|
self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)
|
||||||
|
|
||||||
|
def write_quick_view(self, tif_path, color_img=False, quick_view_path=None):
|
||||||
|
"""
|
||||||
|
生成快视图,默认快视图和影像同路径且同名
|
||||||
|
:param tif_path:影像路径
|
||||||
|
:param color_img:是否生成随机伪彩色图
|
||||||
|
:param quick_view_path:快视图路径
|
||||||
|
"""
|
||||||
|
if quick_view_path is None:
|
||||||
|
quick_view_path = os.path.splitext(tif_path)[0]+'.jpg'
|
||||||
|
|
||||||
|
n = self.get_bands(tif_path)
|
||||||
|
if n == 1: # 单波段
|
||||||
|
t_data = self.get_data(tif_path)
|
||||||
|
else: # 多波段,转为强度数据
|
||||||
|
t_data = self.get_data(tif_path)
|
||||||
|
t_data = t_data.astype(float)
|
||||||
|
t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)
|
||||||
|
t_data[np.isnan(t_data)] = 0
|
||||||
|
t_data[np.where(t_data == -9999)] = 0
|
||||||
|
t_r = self.get_img_height(tif_path)
|
||||||
|
t_c = self.get_img_width(tif_path)
|
||||||
|
if t_r > 10000 or t_c > 10000:
|
||||||
|
q_r = int(t_r / 10)
|
||||||
|
q_c = int(t_c / 10)
|
||||||
|
elif 1024 < t_r < 10000 or 1024 < t_c < 10000:
|
||||||
|
if t_r > t_c:
|
||||||
|
q_r = 1024
|
||||||
|
q_c = int(t_c/t_r * 1024)
|
||||||
|
else:
|
||||||
|
q_c = 1024
|
||||||
|
q_r = int(t_r/t_c * 1024)
|
||||||
|
else:
|
||||||
|
q_r = t_r
|
||||||
|
q_c = t_c
|
||||||
|
|
||||||
|
if color_img is True:
|
||||||
|
# 生成伪彩色图
|
||||||
|
img = np.zeros((t_r, t_c, 3), dtype=np.uint8) # (高,宽,维度)
|
||||||
|
u = np.unique(t_data)
|
||||||
|
for i in u:
|
||||||
|
if i != 0:
|
||||||
|
w = np.where(t_data == i)
|
||||||
|
img[w[0], w[1], 0] = np.random.randint(0, 255) # 随机生成一个0到255之间的整数 可以通过挑参数设定不同的颜色范围
|
||||||
|
img[w[0], w[1], 1] = np.random.randint(0, 255)
|
||||||
|
img[w[0], w[1], 2] = np.random.randint(0, 255)
|
||||||
|
|
||||||
|
img = cv2.resize(img, (q_c, q_r)) # (宽,高)
|
||||||
|
cv2.imwrite(quick_view_path, img)
|
||||||
|
# cv2.imshow("result4", img)
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
else:
|
||||||
|
# 灰度图
|
||||||
|
min = np.percentile(t_data, 2) # np.nanmin(t_data)
|
||||||
|
max = np.percentile(t_data, 98) # np.nanmax(t_data)
|
||||||
|
# if (max - min) < 256:
|
||||||
|
t_data = (t_data - min) / (max - min) * 255
|
||||||
|
out_img = Image.fromarray(t_data)
|
||||||
|
out_img = out_img.resize((q_c, q_r)) # 重采样
|
||||||
|
out_img = out_img.convert("L") # 转换成灰度图
|
||||||
|
out_img.save(quick_view_path)
|
||||||
|
|
||||||
|
def get_center_scopes(self, dataset):
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
img_geotrans = dataset.GetGeoTransform()
|
||||||
|
if img_geotrans is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = dataset.RasterXSize # 栅格矩阵的列数
|
||||||
|
height = dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
|
||||||
|
x_split = int(width/5)
|
||||||
|
y_split = int(height/5)
|
||||||
|
img_col_start = x_split * 1
|
||||||
|
img_col_end = x_split * 3
|
||||||
|
img_row_start = y_split * 1
|
||||||
|
img_row_end = y_split *3
|
||||||
|
cols = img_col_end - img_col_start
|
||||||
|
rows = img_row_end - img_row_start
|
||||||
|
if cols > 10000 or rows > 10000:
|
||||||
|
img_col_end = img_col_start + 10000
|
||||||
|
img_row_end = img_row_start + 10000
|
||||||
|
|
||||||
|
point_upleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_start)
|
||||||
|
point_upright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_start)
|
||||||
|
point_downleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_end)
|
||||||
|
point_downright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_end)
|
||||||
|
|
||||||
|
return [point_upleft, point_upright, point_downleft, point_downright]
|
||||||
|
def write_view(self, tif_path, color_img=False, quick_view_path=None):
|
||||||
|
"""
|
||||||
|
生成快视图,默认快视图和影像同路径且同名
|
||||||
|
:param tif_path:影像路径
|
||||||
|
:param color_img:是否生成随机伪彩色图
|
||||||
|
:param quick_view_path:快视图路径
|
||||||
|
"""
|
||||||
|
if quick_view_path is None:
|
||||||
|
quick_view_path = os.path.splitext(tif_path)[0]+'.jpg'
|
||||||
|
|
||||||
|
n = self.get_bands(tif_path)
|
||||||
|
if n == 1: # 单波段
|
||||||
|
t_data = self.get_data(tif_path)
|
||||||
|
else: # 多波段,转为强度数据
|
||||||
|
t_data = self.get_data(tif_path)
|
||||||
|
t_data = t_data.astype(float)
|
||||||
|
t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)
|
||||||
|
t_data[np.isnan(t_data)] = 0
|
||||||
|
t_data[np.where(t_data == -9999)] = 0
|
||||||
|
t_r = self.get_img_height(tif_path)
|
||||||
|
t_c = self.get_img_width(tif_path)
|
||||||
|
q_r = t_r
|
||||||
|
q_c = t_c
|
||||||
|
|
||||||
|
if color_img is True:
|
||||||
|
# 生成伪彩色图
|
||||||
|
img = np.zeros((t_r, t_c, 3), dtype=np.uint8) # (高,宽,维度)
|
||||||
|
u = np.unique(t_data)
|
||||||
|
for i in u:
|
||||||
|
if i != 0:
|
||||||
|
w = np.where(t_data == i)
|
||||||
|
img[w[0], w[1], 0] = np.random.randint(0, 255) # 随机生成一个0到255之间的整数 可以通过挑参数设定不同的颜色范围
|
||||||
|
img[w[0], w[1], 1] = np.random.randint(0, 255)
|
||||||
|
img[w[0], w[1], 2] = np.random.randint(0, 255)
|
||||||
|
|
||||||
|
img = cv2.resize(img, (q_c, q_r)) # (宽,高)
|
||||||
|
cv2.imwrite(quick_view_path, img)
|
||||||
|
# cv2.imshow("result4", img)
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
else:
|
||||||
|
# 灰度图
|
||||||
|
min = np.percentile(t_data, 2) # np.nanmin(t_data)
|
||||||
|
max = np.percentile(t_data, 98) # np.nanmax(t_data)
|
||||||
|
# if (max - min) < 256:
|
||||||
|
t_data = (t_data - min) / (max - min) * 255
|
||||||
|
out_img = Image.fromarray(t_data)
|
||||||
|
out_img = out_img.resize((q_c, q_r)) # 重采样
|
||||||
|
out_img = out_img.convert("L") # 转换成灰度图
|
||||||
|
out_img.save(quick_view_path)
|
||||||
|
|
||||||
|
return quick_view_path
|
||||||
|
|
||||||
|
def limit_field(self, out_path, in_path, min_value, max_value):
|
||||||
|
"""
|
||||||
|
:param out_path:输出路径
|
||||||
|
:param in_path:主mask路径,输出影像采用主mask的地理信息
|
||||||
|
:param min_value
|
||||||
|
:param max_value
|
||||||
|
"""
|
||||||
|
proj = self.get_projection(in_path)
|
||||||
|
geotrans = self.get_geotransform(in_path)
|
||||||
|
array = self.get_band_array(in_path, 1)
|
||||||
|
array[array < min_value] = min_value
|
||||||
|
array[array > max_value] = max_value
|
||||||
|
self.write_img(out_path, proj, geotrans, array)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def band_merge(self, lon, lat, ori_sim):
|
||||||
|
lon_arr = self.get_data(lon)
|
||||||
|
lat_arr = self.get_data(lat)
|
||||||
|
temp = np.zeros((2, lon_arr.shape[0], lon_arr.shape[1]), dtype=float)
|
||||||
|
temp[0, :, :] = lon_arr[:, :]
|
||||||
|
temp[1, :, :] = lat_arr[:, :]
|
||||||
|
self.write_img(ori_sim, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], temp, '0')
|
||||||
|
|
||||||
|
|
||||||
|
def get_scopes(self, ori_sim):
|
||||||
|
ori_sim_data = self.get_data(ori_sim)
|
||||||
|
lon = ori_sim_data[0, :, :]
|
||||||
|
lat = ori_sim_data[1, :, :]
|
||||||
|
|
||||||
|
min_lon = np.nanmin(np.where((lon != 0) & ~np.isnan(lon), lon, np.inf))
|
||||||
|
max_lon = np.nanmax(np.where((lon != 0) & ~np.isnan(lon), lon, -np.inf))
|
||||||
|
min_lat = np.nanmin(np.where((lat != 0) & ~np.isnan(lat), lat, np.inf))
|
||||||
|
max_lat = np.nanmax(np.where((lat != 0) & ~np.isnan(lat), lat, -np.inf))
|
||||||
|
|
||||||
|
scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]]
|
||||||
|
return scopes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dem_merged(in_dem_path, out_dem_path):
|
||||||
|
'''
|
||||||
|
DEM重采样函数,默认坐标系为WGS84
|
||||||
|
agrs:
|
||||||
|
in_dem_path: 输入的DEM文件夹路径
|
||||||
|
meta_file_path: 输入的xml元文件路径
|
||||||
|
out_dem_path: 输出的DEM文件夹路径
|
||||||
|
'''
|
||||||
|
# 读取文件夹中所有的DEM
|
||||||
|
dem_file_paths = [os.path.join(in_dem_path, dem_name) for dem_name in os.listdir(in_dem_path) if
|
||||||
|
dem_name.find(".tif") >= 0 and dem_name.find(".tif.") == -1]
|
||||||
|
spatialreference = osr.SpatialReference()
|
||||||
|
spatialreference.SetWellKnownGeogCS("WGS84") # 设置地理坐标,单位为度 degree # 设置投影坐标,单位为度 degree
|
||||||
|
spatialproj = spatialreference.ExportToWkt() # 导出投影结果
|
||||||
|
# 将DEM拼接成一张大图
|
||||||
|
mergeFile = gdal.BuildVRT(os.path.join(out_dem_path, "mergedDEM_VRT.tif"), dem_file_paths)
|
||||||
|
out_DEM = os.path.join(out_dem_path, "mergedDEM.tif")
|
||||||
|
gdal.Warp(out_DEM,
|
||||||
|
mergeFile,
|
||||||
|
format="GTiff",
|
||||||
|
dstSRS=spatialproj,
|
||||||
|
dstNodata=-9999,
|
||||||
|
outputType=gdal.GDT_Float32)
|
||||||
|
time.sleep(3)
|
||||||
|
# gdal.CloseDir(out_DEM)
|
||||||
|
return out_DEM
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def landcover_merged(in_dem_path, out_dem_path):
|
||||||
|
'''
|
||||||
|
DEM重采样函数,默认坐标系为WGS84
|
||||||
|
agrs:
|
||||||
|
in_dem_path: 输入的DEM文件夹路径
|
||||||
|
meta_file_path: 输入的xml元文件路径
|
||||||
|
out_dem_path: 输出的DEM文件夹路径
|
||||||
|
'''
|
||||||
|
# 读取文件夹中所有的DEM
|
||||||
|
dem_file_paths = [os.path.join(in_dem_path, dem_name) for dem_name in os.listdir(in_dem_path) if
|
||||||
|
dem_name.find(".tif") >= 0 and dem_name.find(".tif.") == -1]
|
||||||
|
spatialreference = osr.SpatialReference()
|
||||||
|
spatialreference.SetWellKnownGeogCS("WGS84") # 设置地理坐标,单位为度 degree # 设置投影坐标,单位为度 degree
|
||||||
|
spatialproj = spatialreference.ExportToWkt() # 导出投影结果
|
||||||
|
# 将DEM拼接成一张大图
|
||||||
|
mergeFile = gdal.BuildVRT(os.path.join(out_dem_path, "mergedDEM_VRT.tif"), dem_file_paths)
|
||||||
|
out_DEM = os.path.join(out_dem_path, "mergedDEM.tif")
|
||||||
|
gdal.Warp(out_DEM,
|
||||||
|
mergeFile,
|
||||||
|
format="GTiff",
|
||||||
|
dstSRS=spatialproj,
|
||||||
|
dstNodata=-9999,
|
||||||
|
outputType=gdal.GDT_Byte)
|
||||||
|
time.sleep(3)
|
||||||
|
# gdal.CloseDir(out_DEM)
|
||||||
|
return out_DEM
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_inc_angle(inc_xml, rows, cols, out_path):
|
||||||
|
tree = ElementTree()
|
||||||
|
tree.parse(inc_xml) # 影像头文件
|
||||||
|
root = tree.getroot()
|
||||||
|
values = root.findall('incidenceValue')
|
||||||
|
angle_value = [value.text for value in values]
|
||||||
|
angle_value = np.array(angle_value)
|
||||||
|
inc_angle = np.tile(angle_value, (rows, 1))
|
||||||
|
ImageHandler.write_img(out_path, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], inc_angle)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
cols = 7086
|
||||||
|
rows = 8064
|
||||||
|
inc_xml = r"D:\micro\WorkSpace\Dem\Temporary\processing\product\GF3_SAY_FSI_001614_E113.2_N34.5_20161129_L1A_HHHV_L10002015686-DEM.tiff"
|
||||||
|
# ImageHandler().write_quick_view(inc_xml)
|
||||||
|
# fn = r'E:\202306hb\result\GF3B_SYC_QPSI_008316_E116.1_N43.3_20230622_L1A_AHV_L10000202892-cal-SMC.tif'
|
||||||
|
# out = r'E:\202306hb\result\soil.tif'
|
||||||
|
# #
|
||||||
|
# im_proj, im_geotrans, im_arr = ImageHandler.read_img(fn)
|
||||||
|
# im_arr[np.where(np.isnan(im_arr))] = 0
|
||||||
|
# # h,w = im_arr.shape
|
||||||
|
# # arr = np.random.rand(h,w)*0.4
|
||||||
|
# ImageHandler.write_img(out, im_proj, im_geotrans, im_arr, '-1')
|
||||||
|
|
||||||
|
# path = r'D:\BaiduNetdiskDownload\GZ\lon.rdr'
|
||||||
|
# path2 = r'D:\BaiduNetdiskDownload\GZ\lat.rdr'
|
||||||
|
# path3 = r'D:\BaiduNetdiskDownload\GZ\lon_lat.tif'
|
||||||
|
# s = ImageHandler().band_merge(path, path2, path3)
|
||||||
|
# print(s)
|
||||||
|
# pass
|
||||||
|
fileDir = r'D:\micro\WorkSpace\ortho\Temporary\VH_db'
|
||||||
|
outDir = r'D:\micro\WorkSpace\ortho\Temporary\VH_db\test'
|
||||||
|
files = os.listdir(fileDir)
|
||||||
|
for file in files:
|
||||||
|
tifFile = os.path.join(fileDir, file)
|
||||||
|
outFile = os.path.join(outDir, file)
|
||||||
|
im_proj, im_geotrans, im_arr = ImageHandler.read_img(tifFile)
|
||||||
|
im_arr[np.isnan(im_arr)] = 0
|
||||||
|
im_arr[np.isinf(im_arr)] = 0
|
||||||
|
ImageHandler.write_img(outFile, im_proj, im_geotrans, im_arr, '0')
|
||||||
|
|
@ -0,0 +1,508 @@
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from osgeo import ogr,gdal
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from osgeo import gdal
|
||||||
|
import matplotlib
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import matplotlib.patches as patches
|
||||||
|
from osgeo import gdal
|
||||||
|
from PIL import Image
|
||||||
|
from scipy.spatial import cKDTree
|
||||||
|
import numpy as np
|
||||||
|
from DotaOperator import DotaObj,createDota,readDotaFile,writerDotaFile
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
from math import ceil, floor
|
||||||
|
##########################################################################
|
||||||
|
# 函数区
|
||||||
|
##########################################################################
|
||||||
|
def read_tif(path):
|
||||||
|
dataset = gdal.Open(path) # 打开TIF文件
|
||||||
|
if dataset is None:
|
||||||
|
print("无法打开文件")
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
cols = dataset.RasterXSize # 图像宽度
|
||||||
|
rows = dataset.RasterYSize # 图像高度
|
||||||
|
bands = dataset.RasterCount
|
||||||
|
im_proj = dataset.GetProjection() # 获取投影信息
|
||||||
|
im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息
|
||||||
|
im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组
|
||||||
|
print("行数:", rows)
|
||||||
|
print("列数:", cols)
|
||||||
|
print("波段:", bands)
|
||||||
|
del dataset # 关闭数据集
|
||||||
|
return im_proj, im_Geotrans, im_data
|
||||||
|
|
||||||
|
def write_envi(im_data, im_geotrans, im_proj, output_path):
|
||||||
|
"""
|
||||||
|
将数组数据写入ENVI格式文件
|
||||||
|
:param im_data: 输入的numpy数组(2D或3D)
|
||||||
|
:param im_geotrans: 仿射变换参数(6元组)
|
||||||
|
:param im_proj: 投影信息(WKT字符串)
|
||||||
|
:param output_path: 输出文件路径(无需扩展名,会自动生成.dat和.hdr)
|
||||||
|
"""
|
||||||
|
im_bands = 1
|
||||||
|
im_height, im_width = im_data.shape
|
||||||
|
# 创建ENVI格式驱动
|
||||||
|
driver = gdal.GetDriverByName("ENVI")
|
||||||
|
dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte)
|
||||||
|
|
||||||
|
if dataset is not None:
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 设置地理变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 设置投影
|
||||||
|
|
||||||
|
dataset.GetRasterBand(1).WriteArray(im_data)
|
||||||
|
|
||||||
|
dataset.FlushCache() # 确保数据写入磁盘
|
||||||
|
dataset = None # 关闭文件
|
||||||
|
|
||||||
|
def geoXY2pixelXY(geo_x, geo_y, inv_gt):
|
||||||
|
pixel_x = inv_gt[0] + geo_x * inv_gt[1] + geo_y * inv_gt[2]
|
||||||
|
pixel_y = inv_gt[3] + geo_x * inv_gt[4] + geo_y * inv_gt[5]
|
||||||
|
return pixel_x, pixel_y
|
||||||
|
|
||||||
|
def label2pixelpoints(dotapath,tiff_inv_trans,methodstr):
|
||||||
|
dotalist = readDotaFile(dotapath)
|
||||||
|
if methodstr=="geolabel":
|
||||||
|
for i in range(len(dotalist)):
|
||||||
|
geo_x = dotalist[i].x1 # x1
|
||||||
|
geo_y = dotalist[i].y1
|
||||||
|
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
|
||||||
|
dotalist[i].x1 = pixel_x
|
||||||
|
dotalist[i].y1 = pixel_y
|
||||||
|
|
||||||
|
geo_x = dotalist[i].x2 # x2
|
||||||
|
geo_y = dotalist[i].y2
|
||||||
|
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
|
||||||
|
dotalist[i].x2 = pixel_x
|
||||||
|
dotalist[i].y2 = pixel_y
|
||||||
|
|
||||||
|
geo_x = dotalist[i].x3 # x3
|
||||||
|
geo_y = dotalist[i].y3
|
||||||
|
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
|
||||||
|
dotalist[i].x3 = pixel_x
|
||||||
|
dotalist[i].y3 = pixel_y
|
||||||
|
|
||||||
|
geo_x = dotalist[i].x4 # x4
|
||||||
|
geo_y = dotalist[i].y4
|
||||||
|
pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans)
|
||||||
|
dotalist[i].x4 = pixel_x
|
||||||
|
dotalist[i].y4 = pixel_y
|
||||||
|
|
||||||
|
print("点数:", len(dotalist))
|
||||||
|
return dotalist
|
||||||
|
|
||||||
|
def getMaxEdge(dotalist, ids):
|
||||||
|
cornpoint = np.zeros((len(ids) * 4, 2))
|
||||||
|
for idx in range(len(ids)):
|
||||||
|
cornpoint[idx * 4 + 0, 0] = dotalist[ids[idx]].x1
|
||||||
|
cornpoint[idx * 4 + 1, 0] = dotalist[ids[idx]].x2
|
||||||
|
cornpoint[idx * 4 + 2, 0] = dotalist[ids[idx]].x3
|
||||||
|
cornpoint[idx * 4 + 3, 0] = dotalist[ids[idx]].x4
|
||||||
|
|
||||||
|
cornpoint[idx * 4 + 0, 1] = dotalist[ids[idx]].y1
|
||||||
|
cornpoint[idx * 4 + 1, 1] = dotalist[ids[idx]].y2
|
||||||
|
cornpoint[idx * 4 + 2, 1] = dotalist[ids[idx]].y3
|
||||||
|
cornpoint[idx * 4 + 3, 1] = dotalist[ids[idx]].y4
|
||||||
|
|
||||||
|
xedge = np.max(cornpoint[:, 0]) - np.min(cornpoint[:, 0])
|
||||||
|
yedge = np.max(cornpoint[:, 1]) - np.min(cornpoint[:, 1])
|
||||||
|
|
||||||
|
edgelen = xedge if xedge > yedge else yedge
|
||||||
|
return edgelen
|
||||||
|
|
||||||
|
def getExternCenter(dotalist, ids):
|
||||||
|
cornpoint = np.zeros((len(ids) * 4, 2))
|
||||||
|
for idx in range(len(ids)):
|
||||||
|
cornpoint[idx * 4 + 0, 0] = dotalist[ids[idx]].x1
|
||||||
|
cornpoint[idx * 4 + 1, 0] = dotalist[ids[idx]].x2
|
||||||
|
cornpoint[idx * 4 + 2, 0] = dotalist[ids[idx]].x3
|
||||||
|
cornpoint[idx * 4 + 3, 0] = dotalist[ids[idx]].x4
|
||||||
|
|
||||||
|
cornpoint[idx * 4 + 0, 1] = dotalist[ids[idx]].y1
|
||||||
|
cornpoint[idx * 4 + 1, 1] = dotalist[ids[idx]].y2
|
||||||
|
cornpoint[idx * 4 + 2, 1] = dotalist[ids[idx]].y3
|
||||||
|
cornpoint[idx * 4 + 3, 1] = dotalist[ids[idx]].y4
|
||||||
|
|
||||||
|
minX = np.min(cornpoint[:, 0])
|
||||||
|
minY = np.min(cornpoint[:, 1])
|
||||||
|
maxX = np.max(cornpoint[:, 0])
|
||||||
|
maxY = np.max(cornpoint[:, 1])
|
||||||
|
centerX = (minX + maxX) / 2
|
||||||
|
centerY = (minY + maxY) / 2
|
||||||
|
return [centerX, centerY, minX, minY, maxX, maxY]
|
||||||
|
|
||||||
|
def drawSliceRasterPrivew(tiff_data,dotalist,clusterDict):
|
||||||
|
# 绘制图形
|
||||||
|
# 创建图形和坐标轴
|
||||||
|
fig, ax = plt.subplots(figsize=(20, 16))
|
||||||
|
ax.imshow(tiff_data, cmap='gray')
|
||||||
|
# 绘制每个目标的矩形框并标注坐标
|
||||||
|
for i in range(len(dotalist)):
|
||||||
|
# 提取x和y坐标
|
||||||
|
x_coords = [dotalist[i].x1, dotalist[i].x2, dotalist[i].x3, dotalist[i].x4]
|
||||||
|
y_coords = [dotalist[i].y1, dotalist[i].y2, dotalist[i].y3, dotalist[i].y4]
|
||||||
|
|
||||||
|
# 计算最小外接矩形(AABB)
|
||||||
|
x_min, x_max = min(x_coords), max(x_coords)
|
||||||
|
y_min, y_max = min(y_coords), max(y_coords)
|
||||||
|
width = x_max - x_min
|
||||||
|
height = y_max - y_min
|
||||||
|
|
||||||
|
# 绘制无填充矩形框(仅红色边框)
|
||||||
|
rect = patches.Rectangle(
|
||||||
|
(x_min, y_min), width, height,
|
||||||
|
linewidth=2, edgecolor='red', facecolor='none' # 关键:facecolor='none'
|
||||||
|
)
|
||||||
|
ax.add_patch(rect)
|
||||||
|
|
||||||
|
# ax.annotate(f'({x},{y})', xy=(x, y), xytext=(5, 5),
|
||||||
|
# textcoords='offset points', fontsize=10,
|
||||||
|
# bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.8))
|
||||||
|
|
||||||
|
# 在矩形中心标注目标编号
|
||||||
|
center_x = sum(x_coords) / 4
|
||||||
|
center_y = sum(y_coords) / 4
|
||||||
|
ax.text(center_x, center_y, str(i),
|
||||||
|
ha='center', va='center', fontsize=6, color='red')
|
||||||
|
|
||||||
|
# 以类别中心为中心绘制四边形
|
||||||
|
for k in clusterDict:
|
||||||
|
# 绘制无填充矩形框(仅红色边框)
|
||||||
|
minX = clusterDict[k]["p"][0]
|
||||||
|
minY = clusterDict[k]["p"][1]
|
||||||
|
rect = patches.Rectangle(
|
||||||
|
(minX , minY), 1024, 1024,
|
||||||
|
linewidth=2, edgecolor='green', facecolor='none' # 关键:facecolor='none'
|
||||||
|
)
|
||||||
|
ax.add_patch(rect)
|
||||||
|
ax.text(minX+512, minY+512, str(k),
|
||||||
|
ha='center', va='center', fontsize=6, color='green')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
print("绘图结束")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_optimal_slices(H, W, boxes, patch_size=1024, max_overlap_rate=0.2):
|
||||||
|
"""
|
||||||
|
Compute optimal slice positions for the image to maximize the number of fully contained rectangular patches (boxes),
|
||||||
|
while ensuring the overlap rate between any two slices does not exceed the specified maximum.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- H: Height of the image.
|
||||||
|
- W: Width of the image.
|
||||||
|
- boxes: List of tuples or lists, each containing (x1, y1, x2, y2) where (x1, y1) is the top-left and (x2, y2) is the bottom-right of a rectangular patch.
|
||||||
|
- patch_size: Size of each slice (square, e.g., 1024).
|
||||||
|
- max_overlap_rate: Maximum allowed overlap rate (e.g., 0.2).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- slices: List of (sx, sy) starting positions for the slices.
|
||||||
|
- covered_count: Number of patches that appear fully in at least one slice.
|
||||||
|
"""
|
||||||
|
overlap_max = patch_size * max_overlap_rate
|
||||||
|
stride = patch_size - floor(overlap_max) # Ensures overlap <= max_overlap_rate in linear dimensions
|
||||||
|
|
||||||
|
N = len(boxes)
|
||||||
|
x_covered = [set() for _ in range(stride)]
|
||||||
|
y_covered = [set() for _ in range(stride)]
|
||||||
|
|
||||||
|
for i in range(N):
|
||||||
|
x1, y1, x2, y2 = boxes[i]
|
||||||
|
b_w = x2 - x1
|
||||||
|
b_h = y2 - y1
|
||||||
|
|
||||||
|
# For x-dimension
|
||||||
|
lx = max(0, x2 - patch_size)
|
||||||
|
rx = min(W - patch_size, x1)
|
||||||
|
if lx <= rx:
|
||||||
|
start_x = ceil(lx)
|
||||||
|
end_x = floor(rx)
|
||||||
|
l_x = end_x - start_x + 1
|
||||||
|
if l_x > 0:
|
||||||
|
if l_x >= stride:
|
||||||
|
for ox in range(stride):
|
||||||
|
x_covered[ox].add(i)
|
||||||
|
else:
|
||||||
|
for sx in range(start_x, end_x + 1):
|
||||||
|
ox = sx % stride
|
||||||
|
x_covered[ox].add(i)
|
||||||
|
|
||||||
|
# For y-dimension
|
||||||
|
ly = max(0, y2 - patch_size)
|
||||||
|
ry = min(H - patch_size, y1)
|
||||||
|
if ly <= ry:
|
||||||
|
start_y = ceil(ly)
|
||||||
|
end_y = floor(ry)
|
||||||
|
l_y = end_y - start_y + 1
|
||||||
|
if l_y > 0:
|
||||||
|
if l_y >= stride:
|
||||||
|
for oy in range(stride):
|
||||||
|
y_covered[oy].add(i)
|
||||||
|
else:
|
||||||
|
for sy in range(start_y, end_y + 1):
|
||||||
|
oy = sy % stride
|
||||||
|
y_covered[oy].add(i)
|
||||||
|
|
||||||
|
# Find the best offset pair (ox, oy) that maximizes covered patches
|
||||||
|
max_covered = 0
|
||||||
|
best_ox = 0
|
||||||
|
best_oy = 0
|
||||||
|
for ox in range(stride):
|
||||||
|
for oy in range(stride):
|
||||||
|
current_covered = len(x_covered[ox] & y_covered[oy])
|
||||||
|
if current_covered > max_covered:
|
||||||
|
max_covered = current_covered
|
||||||
|
best_ox = ox
|
||||||
|
best_oy = oy
|
||||||
|
|
||||||
|
# Generate the slice positions using the best offsets and stride
|
||||||
|
slices = []
|
||||||
|
sx = best_ox
|
||||||
|
while sx + patch_size <= W:
|
||||||
|
sy = best_oy
|
||||||
|
while sy + patch_size <= H:
|
||||||
|
slices.append((sx, sy))
|
||||||
|
sy += stride
|
||||||
|
sx += stride
|
||||||
|
|
||||||
|
return slices, max_covered
|
||||||
|
|
||||||
|
def check_B_in_A(A,B):
|
||||||
|
"""
|
||||||
|
判断A包含B
|
||||||
|
:param A: [x0,y0.w.h]
|
||||||
|
:param B: [x0,y0.w.h]
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 解构矩形A和B的参数
|
||||||
|
Ax0, Ay0, Aw, Ah = A
|
||||||
|
Bx0, By0, Bw, Bh = B
|
||||||
|
|
||||||
|
# 计算矩形A和B的右边界和下边界
|
||||||
|
Ax1 = Ax0 + Aw
|
||||||
|
Ay1 = Ay0 + Ah
|
||||||
|
Bx1 = Bx0 + Bw
|
||||||
|
By1 = By0 + Bh
|
||||||
|
|
||||||
|
# 判断B是否完全在A内部
|
||||||
|
return (Bx0 >= Ax0) and (Bx1 <= Ax1) and (By0 >= Ay0) and (By1 <= Ay1)
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# 切分算法流程图
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
def getclusterDict(dotalist,imgheight,imgwidth,pitchSize=1024,max_overlap_rate=0.2):
|
||||||
|
"""
|
||||||
|
生成切片数据
|
||||||
|
:param dotalist: 样本集
|
||||||
|
:param imgheight: 图像高度
|
||||||
|
:param imgwidth: 图像宽度
|
||||||
|
:return: 切片类型
|
||||||
|
"""
|
||||||
|
boxs=[]
|
||||||
|
for i in range(len(dotalist)):
|
||||||
|
xs=np.array([dotalist[i].x1,dotalist[i].x2,dotalist[i].x3, dotalist[i].x4])
|
||||||
|
ys=np.array([dotalist[i].y1,dotalist[i].y2,dotalist[i].y3, dotalist[i].y4])
|
||||||
|
x1=np.min(xs)
|
||||||
|
x2=np.max(xs)
|
||||||
|
y1=np.min(ys)
|
||||||
|
y2=np.max(ys)
|
||||||
|
boxs.append([x1,y1,x2,y2]) # x1, y1, x2, y2 = boxes[i]
|
||||||
|
|
||||||
|
slices, max_covered=find_optimal_slices(imgheight,imgwidth,boxs,pitchSize,max_overlap_rate)
|
||||||
|
|
||||||
|
clusterDict={}
|
||||||
|
|
||||||
|
waitContaindota=[]
|
||||||
|
hasContainIds=[]
|
||||||
|
for i in range(len(slices)):
|
||||||
|
sx,sy=slices[i]
|
||||||
|
clusterDict[i]={"p":[sx,sy],"id":[]}
|
||||||
|
slicesExten=[sx,sy,1024,1024]
|
||||||
|
for ids in range(len(dotalist)):
|
||||||
|
if ids in hasContainIds:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
[centerX, centerY, minX, minY, maxX, maxY]=getExternCenter(dotalist, [ids])
|
||||||
|
dotaExtend=[minX,minY,maxX-minX,maxY-minY]
|
||||||
|
if check_B_in_A(slicesExten,dotaExtend):
|
||||||
|
clusterDict[i]["id"].append(ids)
|
||||||
|
hasContainIds.append(ids)
|
||||||
|
|
||||||
|
for ids in range(len(dotalist)):
|
||||||
|
if ids in hasContainIds:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
waitContaindota.append(ids)
|
||||||
|
print("No in slice dota : ",str(dotalist[ids]) )
|
||||||
|
|
||||||
|
print("no process ids ",str(waitContaindota))
|
||||||
|
return clusterDict
|
||||||
|
|
||||||
|
|
||||||
|
def drawSlictplot(clusterDict,dotalist,tiff_data,nrows=10,ncols=9):
|
||||||
|
"""
|
||||||
|
:param clusterDict: clusterDict[i]={"p":[sx,sy],"id":[]}
|
||||||
|
:param dotalist: (x1, y1, x2, y2, x3, y3, x4, y4 clsname diffcule)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
fig, axes = plt.subplots(nrows=nrows,ncols=ncols,figsize=(20, 16))
|
||||||
|
plt.tight_layout(pad=3.0)
|
||||||
|
|
||||||
|
# 9*10
|
||||||
|
subid=0
|
||||||
|
for cid in clusterDict:
|
||||||
|
sx,sy=clusterDict[cid]["p"]
|
||||||
|
colid=subid//nrows
|
||||||
|
rowid=subid%nrows
|
||||||
|
subid=subid+1
|
||||||
|
ax = axes[rowid, colid]
|
||||||
|
ax.set_title(str(cid))
|
||||||
|
sliceData=tiff_data[sy:(sy+1024),sx:(sx+1024)]
|
||||||
|
ax.imshow(sliceData, cmap='gray')
|
||||||
|
|
||||||
|
for did in clusterDict[cid]["id"] :
|
||||||
|
# 提取x和y坐标
|
||||||
|
x_coords = [dotalist[did].x1-sx, dotalist[did].x2-sx, dotalist[did].x3-sx, dotalist[did].x4-sx]
|
||||||
|
y_coords = [dotalist[did].y1-sy, dotalist[did].y2-sy, dotalist[did].y3-sy, dotalist[did].y4-sy]
|
||||||
|
|
||||||
|
# 计算最小外接矩形(AABB)
|
||||||
|
x_min, x_max = min(x_coords), max(x_coords)
|
||||||
|
y_min, y_max = min(y_coords), max(y_coords)
|
||||||
|
width = x_max - x_min
|
||||||
|
height = y_max - y_min
|
||||||
|
|
||||||
|
|
||||||
|
# 绘制无填充矩形框(仅红色边框)
|
||||||
|
rect = patches.Rectangle(
|
||||||
|
(x_min, y_min), width, height,
|
||||||
|
linewidth=2, edgecolor='red', facecolor='none' # 关键:facecolor='none'
|
||||||
|
)
|
||||||
|
ax.add_patch(rect)
|
||||||
|
|
||||||
|
# 在矩形中心标注目标编号
|
||||||
|
center_x = x_min+width/2
|
||||||
|
center_y = y_min+height/2
|
||||||
|
ax.text(center_x, center_y, str(did),
|
||||||
|
ha='center', va='center', fontsize=6, color='red')
|
||||||
|
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
print("绘图结束")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def slictDataAndOutlabel(clusterDict,dotalist,tiff_data,tiff_basename,outfolderpath,im_geotrans, im_proj):
|
||||||
|
"""
|
||||||
|
切分标签,输出结果与文件
|
||||||
|
:param clusterDict:
|
||||||
|
:param dotalist:
|
||||||
|
:param tiff_data:
|
||||||
|
:param tiff_name:
|
||||||
|
:param outfolderpath:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for cid in clusterDict:
|
||||||
|
sx, sy = clusterDict[cid]["p"]
|
||||||
|
sliceData = tiff_data[sy:(sy + 1024), sx:(sx + 1024)]
|
||||||
|
outbinname="{}_{}.bin".format(tiff_basename,cid)
|
||||||
|
outlabelname="{}_{}.txt".format(tiff_basename,cid)
|
||||||
|
# 获取样本列表
|
||||||
|
outdotalist=[]
|
||||||
|
for did in clusterDict[cid]["id"] :
|
||||||
|
tempdota=dotalist[did]
|
||||||
|
tempdota.x1=tempdota.x1-sx
|
||||||
|
tempdota.x2=tempdota.x2-sx
|
||||||
|
tempdota.x3=tempdota.x3-sx
|
||||||
|
tempdota.x4=tempdota.x4-sx
|
||||||
|
tempdota.y1=tempdota.y1-sy
|
||||||
|
tempdota.y2=tempdota.y2-sy
|
||||||
|
tempdota.y3=tempdota.y3-sy
|
||||||
|
tempdota.y4=tempdota.y4-sy
|
||||||
|
outdotalist.append(tempdota)
|
||||||
|
|
||||||
|
outlabelpath=os.path.join(outfolderpath,outlabelname)
|
||||||
|
outbinpath=os.path.join(outfolderpath,outbinname)
|
||||||
|
|
||||||
|
temp_im_geotrans=[tempi for tempi in im_geotrans]
|
||||||
|
# 处理 0,3
|
||||||
|
temp_im_geotrans[0]=im_geotrans[0]+sx*im_geotrans[1]+im_geotrans[2]*sy # x
|
||||||
|
temp_im_geotrans[3]=im_geotrans[3]+sx*im_geotrans[4]+im_geotrans[5]*sy # y
|
||||||
|
write_envi(sliceData,temp_im_geotrans,im_proj,outbinpath)
|
||||||
|
writerDotaFile(outdotalist,outlabelpath)
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# 处理流程图
|
||||||
|
##########################################################################
|
||||||
|
def DataSampleSliceRasterProcess(inbinfile,labelfilepath,outfolderpath,methodstr):
|
||||||
|
tiff_proj, tiff_trans, tiff_data = read_tif(inbinfile)
|
||||||
|
tiff_inv_trans = gdal.InvGeoTransform(tiff_trans)
|
||||||
|
dotalist=label2pixelpoints(labelfilepath,tiff_inv_trans,methodstr)
|
||||||
|
imgheight, imgwidth=tiff_data.shape
|
||||||
|
clusterDict=getclusterDict(dotalist,imgheight,imgwidth)
|
||||||
|
drawSliceRasterPrivew(tiff_data, dotalist, clusterDict)
|
||||||
|
ncols=int(len(clusterDict)/9+1)
|
||||||
|
drawSlictplot(clusterDict, dotalist, tiff_data, 9, ncols)
|
||||||
|
tiff_name=os.path.basename(inbinfile)
|
||||||
|
tiff_basename=os.path.splitext(tiff_name)[0]
|
||||||
|
slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath, tiff_trans, tiff_proj)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def getParams():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-i','--inbinfile',type=str,default=r'F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.bin', help='输入tiff的bin文件')
|
||||||
|
parser.add_argument('-l', '--labelfilepath',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\标注\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01_LC.txt", help='输入标注')
|
||||||
|
parser.add_argument('-o', '--outfolderpath',type=str,default=r'F:\天仪SAR卫星数据集\舰船数据\切分3', help='切片文件夹地址')
|
||||||
|
group = parser.add_mutually_exclusive_group()
|
||||||
|
group.add_argument(
|
||||||
|
'--geolabel',
|
||||||
|
action='store_const',
|
||||||
|
const='geolabel',
|
||||||
|
dest='method',
|
||||||
|
help='标注坐标点为地理坐标'
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
'--pixellabel',
|
||||||
|
action='store_const',
|
||||||
|
const='pixellabel',
|
||||||
|
dest='method',
|
||||||
|
help='标注坐标系统为输入影像的像空间坐标'
|
||||||
|
)
|
||||||
|
parser.set_defaults(method='geolabel')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = getParams()
|
||||||
|
inbinfile=parser.inbinfile
|
||||||
|
labelfilepath=parser.labelfilepath
|
||||||
|
outfolderpath=parser.outfolderpath
|
||||||
|
methodstr= parser.method
|
||||||
|
print('inbinfile=',inbinfile)
|
||||||
|
print('labelfilepath=',labelfilepath)
|
||||||
|
print('outfolderpath=',outfolderpath)
|
||||||
|
print('methodstr=',methodstr)
|
||||||
|
DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath,methodstr)
|
||||||
|
print("样本切分完成")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
"""
|
||||||
|
Dota数据集标注
|
||||||
|
每条标注数据应包含:
|
||||||
|
+前 8 个数值,表示目标的四个角点的坐标(x1, y1, x2, y2, x3, y3, x4, y4),按顺时针或逆时针顺序排列,无需进行归一化处理;
|
||||||
|
+倒数第二列为标注的目标类别名称,如:jun_ship等;
|
||||||
|
+最后一列为识别目标的难易程度(difficulty)
|
||||||
|
|
||||||
|
提供数据集标注
|
||||||
|
|
||||||
|
"""
|
||||||
|
class DotaObj(object):
|
||||||
|
def __init__(self, x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty):
|
||||||
|
self.x1=x1
|
||||||
|
self.y1=y1
|
||||||
|
self.x2=x2
|
||||||
|
self.y2=y2
|
||||||
|
self.x3=x3
|
||||||
|
self.y3=y3
|
||||||
|
self.x4=x4
|
||||||
|
self.y4=y4
|
||||||
|
self.clsname=clsname
|
||||||
|
self.difficulty=difficulty
|
||||||
|
def __str__(self):
|
||||||
|
return "{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}".format(
|
||||||
|
self.x1,self.y1,
|
||||||
|
self.x2,self.y2,
|
||||||
|
self.x3,self.y3,
|
||||||
|
self.x4,self.y4,
|
||||||
|
self.clsname,self.difficulty
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def createDota(x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty):
|
||||||
|
return DotaObj(x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty) # 8+2
|
||||||
|
|
||||||
|
|
||||||
|
def readDotaFile(dotafilepath):
|
||||||
|
content=None
|
||||||
|
with open(dotafilepath,'r',encoding="utf-8") as fp:
|
||||||
|
content=fp.read()
|
||||||
|
contentlines=content.split("\n")
|
||||||
|
# 逐行分解
|
||||||
|
result=[]
|
||||||
|
for linestr in contentlines:
|
||||||
|
linestr=linestr.replace("\t"," ")
|
||||||
|
linestr=linestr.replace(" "," ")
|
||||||
|
linemetas=linestr.split(" ")
|
||||||
|
if(len(linemetas)>=10):
|
||||||
|
x1=float(linemetas[0])
|
||||||
|
y1=float(linemetas[1])
|
||||||
|
x2=float(linemetas[2])
|
||||||
|
y2=float(linemetas[3])
|
||||||
|
x3=float(linemetas[4])
|
||||||
|
y3=float(linemetas[5])
|
||||||
|
x4=float(linemetas[6])
|
||||||
|
y4=float(linemetas[7])
|
||||||
|
clsname=linemetas[8]
|
||||||
|
difficulty=linemetas[9]
|
||||||
|
result.append(createDota(x1, y1, x2, y2, x3, y3, x4, y4, clsname,difficulty))
|
||||||
|
else:
|
||||||
|
print("parse result: ", linestr)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def writerDotaFile(dotalist,filepath):
|
||||||
|
with open(filepath,'a',encoding="utf-8") as fp:
|
||||||
|
for dota in dotalist:
|
||||||
|
if isinstance(dota,DotaObj):
|
||||||
|
fp.write("{}\n".format(str(dota)))
|
||||||
|
else:
|
||||||
|
fp.write("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}\n".format(
|
||||||
|
dota[0],dota[1],dota[2],dota[3],dota[4],dota[5],dota[6],dota[7],dota[8],dota[9]
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
from osgeo import ogr, gdal
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def shp_to_geojson(shp_path, geojson_path):
|
||||||
|
# 设置编码选项
|
||||||
|
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
|
||||||
|
gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
|
||||||
|
|
||||||
|
# 打开Shapefile
|
||||||
|
src_ds = ogr.Open(shp_path)
|
||||||
|
src_layer = src_ds.GetLayer(0)
|
||||||
|
|
||||||
|
# 创建输出GeoJSON
|
||||||
|
driver = ogr.GetDriverByName('GeoJSON')
|
||||||
|
if os.path.exists(geojson_path):
|
||||||
|
driver.DeleteDataSource(geojson_path)
|
||||||
|
dst_ds = driver.CreateDataSource(geojson_path)
|
||||||
|
dst_layer = dst_ds.CreateLayer('output', src_layer.GetSpatialRef())
|
||||||
|
|
||||||
|
# 复制字段定义
|
||||||
|
dst_layer.CreateFields(src_layer.schema)
|
||||||
|
|
||||||
|
# 复制要素
|
||||||
|
for feature in src_layer:
|
||||||
|
dst_feature = ogr.Feature(dst_layer.GetLayerDefn())
|
||||||
|
dst_feature.SetGeometry(feature.GetGeometryRef())
|
||||||
|
for j in range(feature.GetFieldCount()):
|
||||||
|
dst_feature.SetField(j, feature.GetField(j))
|
||||||
|
dst_layer.CreateFeature(dst_feature)
|
||||||
|
|
||||||
|
# 清理
|
||||||
|
dst_ds = None
|
||||||
|
src_ds = None
|
||||||
|
print("转换完成")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
shp_to_geojson('input.shp', 'output.geojson')
|
||||||
|
|
@ -0,0 +1,220 @@
|
||||||
|
from osgeo import ogr, gdal
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
def get_filename_without_ext(path):
|
||||||
|
base_name = os.path.basename(path)
|
||||||
|
if '.' not in base_name or base_name.startswith('.'):
|
||||||
|
return base_name
|
||||||
|
return base_name.rsplit('.', 1)[0]
|
||||||
|
|
||||||
|
def read_tif(path):
|
||||||
|
dataset = gdal.Open(path) # 打开TIF文件
|
||||||
|
if dataset is None:
|
||||||
|
print("无法打开文件")
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
cols = dataset.RasterXSize # 图像宽度
|
||||||
|
rows = dataset.RasterYSize # 图像高度
|
||||||
|
bands = dataset.RasterCount
|
||||||
|
im_proj = dataset.GetProjection() # 获取投影信息
|
||||||
|
im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息
|
||||||
|
im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组
|
||||||
|
print("行数:", rows)
|
||||||
|
print("列数:", cols)
|
||||||
|
print("波段:", bands)
|
||||||
|
del dataset # 关闭数据集
|
||||||
|
return im_proj, im_Geotrans, im_data
|
||||||
|
|
||||||
|
def write_envi(im_data, im_geotrans, im_proj, output_path):
|
||||||
|
"""
|
||||||
|
将数组数据写入ENVI格式文件
|
||||||
|
:param im_data: 输入的numpy数组(2D或3D)
|
||||||
|
:param im_geotrans: 仿射变换参数(6元组)
|
||||||
|
:param im_proj: 投影信息(WKT字符串)
|
||||||
|
:param output_path: 输出文件路径(无需扩展名,会自动生成.dat和.hdr)
|
||||||
|
"""
|
||||||
|
im_bands = 1
|
||||||
|
im_height, im_width = im_data.shape
|
||||||
|
# 创建ENVI格式驱动
|
||||||
|
driver = gdal.GetDriverByName("ENVI")
|
||||||
|
dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte)
|
||||||
|
|
||||||
|
if dataset is not None:
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 设置地理变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 设置投影
|
||||||
|
|
||||||
|
dataset.GetRasterBand(1).WriteArray(im_data)
|
||||||
|
|
||||||
|
dataset.FlushCache() # 确保数据写入磁盘
|
||||||
|
dataset = None # 关闭文件
|
||||||
|
|
||||||
|
|
||||||
|
def Strech_linear(im_data):
|
||||||
|
im_data_dB=10*np.log10(im_data)
|
||||||
|
immask=np.isfinite(im_data_dB)
|
||||||
|
infmask = np.isinf(im_data_dB)
|
||||||
|
imvail_data=im_data[immask]
|
||||||
|
im_data_dB=0
|
||||||
|
|
||||||
|
minvalue=np.nanmin(imvail_data)
|
||||||
|
maxvalue=np.nanmax(imvail_data)
|
||||||
|
|
||||||
|
infmask = np.isinf(im_data_dB)
|
||||||
|
im_data[infmask] = minvalue-100
|
||||||
|
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254+1
|
||||||
|
im_data=np.clip(im_data,0,255)
|
||||||
|
return im_data.astype(np.uint8)
|
||||||
|
|
||||||
|
def Strech_linear1(im_data):
|
||||||
|
im_data_dB = 10 * np.log10(im_data)
|
||||||
|
immask = np.isfinite(im_data_dB)
|
||||||
|
infmask = np.isinf(im_data_dB)
|
||||||
|
imvail_data = im_data[immask]
|
||||||
|
im_data_dB=0
|
||||||
|
|
||||||
|
minvalue=np.percentile(imvail_data,1)
|
||||||
|
maxvalue = np.percentile(imvail_data, 99)
|
||||||
|
|
||||||
|
|
||||||
|
im_data[infmask] = minvalue - 100
|
||||||
|
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
|
||||||
|
im_data = np.clip(im_data, 0, 255)
|
||||||
|
|
||||||
|
return im_data.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def Strech_linear2(im_data):
|
||||||
|
im_data_dB = 10 * np.log10(im_data)
|
||||||
|
immask = np.isfinite(im_data_dB)
|
||||||
|
infmask = np.isinf(im_data_dB)
|
||||||
|
imvail_data = im_data[immask]
|
||||||
|
im_data_dB = 0
|
||||||
|
|
||||||
|
minvalue = np.percentile(imvail_data, 2)
|
||||||
|
maxvalue = np.percentile(imvail_data, 98)
|
||||||
|
|
||||||
|
im_data[infmask] = minvalue - 100
|
||||||
|
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
|
||||||
|
im_data = np.clip(im_data, 0, 255)
|
||||||
|
|
||||||
|
return im_data.astype(np.uint8)
|
||||||
|
|
||||||
|
def Strech_linear5(im_data):
|
||||||
|
im_data_dB = 10 * np.log10(im_data)
|
||||||
|
immask = np.isfinite(im_data_dB)
|
||||||
|
infmask = np.isinf(im_data_dB)
|
||||||
|
imvail_data = im_data[immask]
|
||||||
|
im_data_dB = 0
|
||||||
|
|
||||||
|
minvalue = np.percentile(imvail_data, 5)
|
||||||
|
maxvalue = np.percentile(imvail_data, 95)
|
||||||
|
|
||||||
|
im_data[infmask] = minvalue - 100
|
||||||
|
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
|
||||||
|
im_data = np.clip(im_data, 0, 255)
|
||||||
|
|
||||||
|
return im_data.astype(np.uint8)
|
||||||
|
|
||||||
|
def Strech_SquareRoot(im_data):
|
||||||
|
im_data=np.sqrt(im_data)
|
||||||
|
immask = np.isfinite(im_data)
|
||||||
|
imvail_data = im_data[immask]
|
||||||
|
|
||||||
|
minvalue=np.nanmin(imvail_data)
|
||||||
|
maxvalue=np.nanmax(imvail_data)
|
||||||
|
im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1
|
||||||
|
im_data = np.clip(im_data, 0, 255)
|
||||||
|
return im_data.astype(np.uint8)
|
||||||
|
|
||||||
|
def DataStrech(im_data,strechmethod):
|
||||||
|
# [,"Linear1","Linear2","Linear5","SquareRoot"]
|
||||||
|
if strechmethod == "Linear" :
|
||||||
|
return Strech_linear(im_data)
|
||||||
|
elif strechmethod == "Linear1":
|
||||||
|
return Strech_linear1(im_data)
|
||||||
|
elif strechmethod == "Linear2":
|
||||||
|
return Strech_linear2(im_data)
|
||||||
|
elif strechmethod == "Linear5":
|
||||||
|
return Strech_linear5(im_data)
|
||||||
|
elif strechmethod == "SquareRoot":
|
||||||
|
return Strech_SquareRoot(im_data)
|
||||||
|
else:
|
||||||
|
return im_data.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def stretchProcess(infilepath,outfilepath,strechmethod):
|
||||||
|
im_proj, im_Geotrans, im_data=read_tif(infilepath)
|
||||||
|
envifilepath=get_filename_without_ext(outfilepath)+".bin"
|
||||||
|
envifilepath=os.path.join(os.path.dirname(outfilepath),envifilepath)
|
||||||
|
im_data = DataStrech(im_data,strechmethod)
|
||||||
|
im_data = im_data.astype(np.uint8)
|
||||||
|
write_envi(im_data,im_Geotrans,im_proj,envifilepath)
|
||||||
|
Image.fromarray(im_data).save(outfilepath,compress_level=0)
|
||||||
|
print("图像拉伸处理结束")
|
||||||
|
|
||||||
|
def getParams():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-i','--infile',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.tiff", help='输入shapefile文件')
|
||||||
|
parser.add_argument('-o', '--outfile',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.png", help='输出geojson文件')
|
||||||
|
group = parser.add_mutually_exclusive_group()
|
||||||
|
group.add_argument(
|
||||||
|
'--Linear',
|
||||||
|
action='store_const',
|
||||||
|
const='Linear',
|
||||||
|
dest='method',
|
||||||
|
help='线性拉伸'
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
'--Linear1prec',
|
||||||
|
action='store_const',
|
||||||
|
const='Linear1',
|
||||||
|
dest='method',
|
||||||
|
help='1%线性拉伸'
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
'--Linear2prec',
|
||||||
|
action='store_const',
|
||||||
|
const='Linear2',
|
||||||
|
dest='method',
|
||||||
|
help='2%线性拉伸'
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
'--Linear5prec',
|
||||||
|
action='store_const',
|
||||||
|
const='Linear5',
|
||||||
|
dest='method',
|
||||||
|
help='5%线性拉伸'
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
'--SquareRoot',
|
||||||
|
action='store_const',
|
||||||
|
const='SquareRoot',
|
||||||
|
dest='method',
|
||||||
|
help='平方根拉伸'
|
||||||
|
)
|
||||||
|
parser.set_defaults(method='SquareRoot')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = getParams()
|
||||||
|
intiffPath=parser.infile
|
||||||
|
outbinPath=parser.outfile
|
||||||
|
methodstr=parser.method
|
||||||
|
print('infile=',intiffPath)
|
||||||
|
print('outfile=',outbinPath)
|
||||||
|
print('method=',methodstr)
|
||||||
|
stretchProcess(intiffPath, outbinPath, methodstr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
from osgeo import ogr
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def shapefile_to_dota(shp_path, output_path, class_field='class', difficulty_value=1):
|
||||||
|
"""
|
||||||
|
将Shapefile转换为DOTA格式
|
||||||
|
:param shp_path: Shapefile文件路径
|
||||||
|
:param output_path: 输出目录
|
||||||
|
:param class_field: 类别字段名
|
||||||
|
:param difficulty_value: 难度默认字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 注册所有驱动
|
||||||
|
ogr.RegisterAll()
|
||||||
|
|
||||||
|
# 打开Shapefile文件
|
||||||
|
driver = ogr.GetDriverByName('ESRI Shapefile')
|
||||||
|
datasource = driver.Open(shp_path, 0)
|
||||||
|
if datasource is None:
|
||||||
|
print("无法打开Shapefile文件")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取图层
|
||||||
|
layer = datasource.GetLayer()
|
||||||
|
|
||||||
|
output_file = output_path
|
||||||
|
|
||||||
|
with open(output_file, 'w',encoding="utf-8") as f:
|
||||||
|
# 写入DOTA格式头信息(可选)
|
||||||
|
# f.write('imagesource:unknown\n')
|
||||||
|
# f.write('gsd:1.0\n')
|
||||||
|
|
||||||
|
# 遍历所有要素
|
||||||
|
for feature in layer:
|
||||||
|
# 获取几何对象
|
||||||
|
geom = feature.GetGeometryRef()
|
||||||
|
|
||||||
|
# 获取类别和难度
|
||||||
|
class_name = feature.GetField(class_field) if feature.GetField(class_field) else 'unknown'
|
||||||
|
difficulty = difficulty_value
|
||||||
|
|
||||||
|
# 处理不同类型的几何图形
|
||||||
|
if geom.GetGeometryName() == 'POLYGON':
|
||||||
|
# 获取多边形外环
|
||||||
|
ring = geom.GetGeometryRef(0)
|
||||||
|
# 获取所有点
|
||||||
|
points = []
|
||||||
|
for i in range(ring.GetPointCount()):
|
||||||
|
points.append(ring.GetPoint(i))
|
||||||
|
|
||||||
|
# 确保有足够的点(至少4个)
|
||||||
|
if len(points) >= 4:
|
||||||
|
# 取前4个点作为DOTA格式的四个角点
|
||||||
|
# 注意: DOTA要求按顺序排列(顺时针或逆时针)
|
||||||
|
x1, y1 = points[0][0], points[0][1]
|
||||||
|
x2, y2 = points[1][0], points[1][1]
|
||||||
|
x3, y3 = points[2][0], points[2][1]
|
||||||
|
x4, y4 = points[3][0], points[3][1]
|
||||||
|
|
||||||
|
# 写入DOTA格式行
|
||||||
|
line = f"{x1} {y1} {x2} {y2} {x3} {y3} {x4} {y4} {class_name} {difficulty}\n"
|
||||||
|
f.write(line)
|
||||||
|
|
||||||
|
# 释放资源
|
||||||
|
datasource.Destroy()
|
||||||
|
print("转换完毕")
|
||||||
|
|
||||||
|
def getParams():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-i','--infile',type=str,default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.shp', help='输入shapefile文件')
|
||||||
|
parser.add_argument('-o', '--outfile',type=str,default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.txt', help='输出geojson文件')
|
||||||
|
parser.add_argument('-c', '--clsname',type=str,default=r'Label', help='输出geojson文件')
|
||||||
|
parser.add_argument('-d', '--difficulty',type=int,default=1, help='输出geojson文件')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = getParams()
|
||||||
|
inFilePath=parser.infile
|
||||||
|
outpath=parser.outfile
|
||||||
|
clsname=parser.clsname
|
||||||
|
difficulty=parser.difficulty
|
||||||
|
print('infile=',inFilePath)
|
||||||
|
print('outfile=',outpath)
|
||||||
|
print('clsname=',clsname)
|
||||||
|
print('difficulty=',difficulty)
|
||||||
|
shapefile_to_dota(inFilePath, outpath, clsname, difficulty)
|
||||||
|
# 使用示例
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
from osgeo import ogr, gdal
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def shp_to_geojson(shp_path, geojson_path):
|
||||||
|
# 设置编码选项
|
||||||
|
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
|
||||||
|
gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
|
||||||
|
|
||||||
|
# 打开Shapefile
|
||||||
|
src_ds = ogr.Open(shp_path)
|
||||||
|
src_layer = src_ds.GetLayer(0)
|
||||||
|
|
||||||
|
# 创建输出GeoJSON
|
||||||
|
driver = ogr.GetDriverByName('GeoJSON')
|
||||||
|
if os.path.exists(geojson_path):
|
||||||
|
driver.DeleteDataSource(geojson_path)
|
||||||
|
dst_ds = driver.CreateDataSource(geojson_path)
|
||||||
|
dst_layer = dst_ds.CreateLayer('output', src_layer.GetSpatialRef())
|
||||||
|
|
||||||
|
# 复制字段定义
|
||||||
|
dst_layer.CreateFields(src_layer.schema)
|
||||||
|
|
||||||
|
# 复制要素
|
||||||
|
for feature in src_layer:
|
||||||
|
dst_feature = ogr.Feature(dst_layer.GetLayerDefn())
|
||||||
|
dst_feature.SetGeometry(feature.GetGeometryRef())
|
||||||
|
for j in range(feature.GetFieldCount()):
|
||||||
|
dst_feature.SetField(j, feature.GetField(j))
|
||||||
|
dst_layer.CreateFeature(dst_feature)
|
||||||
|
|
||||||
|
# 清理
|
||||||
|
dst_ds = None
|
||||||
|
src_ds = None
|
||||||
|
print("转换完成")
|
||||||
|
|
||||||
|
def getParams():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-i','--infile',default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.shp', help='输入shapefile文件')
|
||||||
|
parser.add_argument('-o', '--outfile',default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.json', help='输出geojson文件')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = getParams()
|
||||||
|
inFilePath=parser.infile
|
||||||
|
outpath=parser.outfile
|
||||||
|
print('infile=',inFilePath)
|
||||||
|
print('outfile=',outpath)
|
||||||
|
shp_to_geojson(inFilePath, outpath)
|
||||||
Loading…
Reference in New Issue