SpacetySliceTools/tools/DataSampleSliceTrainDataset.py

437 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import argparse
from osgeo import ogr,gdal
from matplotlib import pyplot as plt
from osgeo import gdal
import matplotlib
import matplotlib.patches as patches
from osgeo import gdal
from PIL import Image
from scipy.spatial import cKDTree
import numpy as np
from DotaOperator import DotaObj,createDota,readDotaFile,writerDotaFile
import argparse
import math
from math import ceil, floor
import random
##########################################################################
# 函数区
##########################################################################
def read_tif(path):
    """Open a raster with GDAL and return its georeferencing and pixel data.

    :param path: path handed to ``gdal.Open``
    :return: tuple ``(projection, geotransform, data)``; ``(None, None, None)``
        when GDAL cannot open the file
    """
    raster = gdal.Open(path)
    if raster is None:
        print("无法打开文件")
        return None, None, None

    width = raster.RasterXSize
    height = raster.RasterYSize
    band_count = raster.RasterCount
    projection = raster.GetProjection()
    geotransform = raster.GetGeoTransform()
    # Read the full extent of the raster into a numpy array.
    pixels = raster.ReadAsArray(0, 0, width, height)

    for caption, value in (("行数:", height), ("列数:", width), ("波段:", band_count)):
        print(caption, value)

    # Drop the reference so the dataset is released.
    del raster
    return projection, geotransform, pixels
def write_envi(im_data, im_geotrans, im_proj, output_path):
    """Write a single-band array as an ENVI raster and save a PNG copy next to it.

    The ENVI dataset is created with one band of type ``gdal.GDT_Byte``. The
    PNG is written to ``output_path`` with its final extension swapped for
    ``.png`` (or with ``.png`` appended when the path has no extension).

    :param im_data: 2-D numpy array; its shape is unpacked as (height, width)
    :param im_geotrans: 6-element affine geotransform for ``SetGeoTransform``
    :param im_proj: projection string for ``SetProjection``
    :param output_path: path of the ENVI data file to create
    """
    im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("ENVI")
    dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte)
    if dataset is not None:
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)
        dataset.GetRasterBand(1).WriteArray(im_data)
        dataset.FlushCache()  # make sure the data reaches the disk
        dataset = None  # release the dataset

    # Only swap the final extension: str.replace(".bin", ".png") would also
    # rewrite ".bin" inside directory names and would leave an extension-less
    # path unchanged.
    outfilepath = os.path.splitext(output_path)[0] + ".png"
    Image.fromarray(im_data).save(outfilepath, compress_level=0)
def geoXY2pixelXY(geo_x, geo_y, inv_gt):
    """Apply a 6-coefficient affine transform to a coordinate pair.

    :param geo_x: input x coordinate
    :param geo_y: input y coordinate
    :param inv_gt: indexable with six coefficients; elements 0-2 produce the
        output x, elements 3-5 produce the output y
    :return: tuple ``(pixel_x, pixel_y)``
    """
    col = inv_gt[0] + geo_x * inv_gt[1] + geo_y * inv_gt[2]
    row = inv_gt[3] + geo_x * inv_gt[4] + geo_y * inv_gt[5]
    return col, row
def label2pixelpoints(dotapath, tiff_inv_trans, methodstr, filterlabels):
    """Read the label file and, for geo-referenced labels, convert corners to pixels.

    :param dotapath: label file path passed to ``readDotaFile``
    :param tiff_inv_trans: inverse geotransform passed to ``geoXY2pixelXY``
    :param methodstr: when equal to ``"geolabel"`` the four corners of every
        label are converted in place; any other value leaves them untouched
    :param filterlabels: label filter passed to ``readDotaFile``
    :return: the list returned by ``readDotaFile`` (corners possibly converted)
    """
    dotalist = readDotaFile(dotapath, filterlabels)

    if methodstr == "geolabel":
        for dota in dotalist:
            # Corners are stored as attributes x1..x4 / y1..y4.
            for corner in ("1", "2", "3", "4"):
                x_attr = "x" + corner
                y_attr = "y" + corner
                pixel_x, pixel_y = geoXY2pixelXY(
                    getattr(dota, x_attr), getattr(dota, y_attr), tiff_inv_trans
                )
                setattr(dota, x_attr, pixel_x)
                setattr(dota, y_attr, pixel_y)

    print("点数:", len(dotalist))
    return dotalist
def getMaxEdge(dotalist, ids):
    """Return the longer side of the axis-aligned box around the selected labels.

    :param dotalist: sequence of objects exposing ``x1..x4`` and ``y1..y4``
    :param ids: indices into ``dotalist`` of the labels to enclose
    :return: the larger of the x extent and the y extent of all their corners
    """
    corners = np.array(
        [
            (getattr(dotalist[did], "x" + c), getattr(dotalist[did], "y" + c))
            for did in ids
            for c in "1234"
        ],
        dtype=float,
    ).reshape(-1, 2)

    span_x = np.max(corners[:, 0]) - np.min(corners[:, 0])
    span_y = np.max(corners[:, 1]) - np.min(corners[:, 1])
    if span_x > span_y:
        return span_x
    return span_y
def getExternCenter(dotalist, ids):
    """Compute the axis-aligned bounding box of the selected labels and its centre.

    :param dotalist: sequence of objects exposing ``x1..x4`` and ``y1..y4``
    :param ids: indices into ``dotalist`` of the labels to enclose
    :return: list ``[centerX, centerY, minX, minY, maxX, maxY]``
    """
    corners = np.array(
        [
            (getattr(dotalist[did], "x" + c), getattr(dotalist[did], "y" + c))
            for did in ids
            for c in "1234"
        ],
        dtype=float,
    ).reshape(-1, 2)

    xs = corners[:, 0]
    ys = corners[:, 1]
    lo_x, lo_y = np.min(xs), np.min(ys)
    hi_x, hi_y = np.max(xs), np.max(ys)
    mid_x = (lo_x + hi_x) / 2
    mid_y = (lo_y + hi_y) / 2
    return [mid_x, mid_y, lo_x, lo_y, hi_x, hi_y]
def drawSliceRasterPrivew(tiff_data, dotalist, clusterDict, pitchSize=1024):
    """Show the whole raster with label boxes (red) and slice windows (green).

    :param tiff_data: image array passed to ``imshow`` (drawn with a gray colormap)
    :param dotalist: sequence of labels exposing ``x1..x4`` / ``y1..y4``
    :param clusterDict: mapping ``key -> {"p": [minX, minY], ...}`` where ``p``
        is the minimum-x / minimum-y corner of a slice window
    :param pitchSize: side length of each slice window; defaults to 1024, the
        value that was previously hard-coded
    :return: None
    """
    fig, ax = plt.subplots(figsize=(20, 16))
    ax.imshow(tiff_data, cmap='gray')

    # One red, unfilled axis-aligned rectangle per label, tagged with its index.
    for i, dota in enumerate(dotalist):
        x_coords = [dota.x1, dota.x2, dota.x3, dota.x4]
        y_coords = [dota.y1, dota.y2, dota.y3, dota.y4]
        x_min, x_max = min(x_coords), max(x_coords)
        y_min, y_max = min(y_coords), max(y_coords)
        ax.add_patch(patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=2, edgecolor='red', facecolor='none',
        ))
        # The index is placed at the mean of the four corners.
        ax.text(sum(x_coords) / 4, sum(y_coords) / 4, str(i),
                ha='center', va='center', fontsize=6, color='red')

    # One green, unfilled square per slice window, tagged with its key.
    half = pitchSize / 2
    for k in clusterDict:
        minX = clusterDict[k]["p"][0]
        minY = clusterDict[k]["p"][1]
        ax.add_patch(patches.Rectangle(
            (minX, minY), pitchSize, pitchSize,
            linewidth=2, edgecolor='green', facecolor='none',
        ))
        ax.text(minX + half, minY + half, str(k),
                ha='center', va='center', fontsize=6, color='green')

    plt.tight_layout()
    plt.show()
    print("绘图结束")
    return None
def check_B_in_A(A, B):
    """Tell whether rectangle B lies entirely inside rectangle A (edges may touch).

    :param A: ``[x0, y0, w, h]`` of the outer rectangle
    :param B: ``[x0, y0, w, h]`` of the rectangle being tested
    :return: truthy when B is fully contained in A
    """
    a_left, a_top, a_w, a_h = A
    b_left, b_top, b_w, b_h = B

    left_ok = b_left >= a_left
    right_ok = (b_left + b_w) <= (a_left + a_w)
    top_ok = b_top >= a_top
    bottom_ok = (b_top + b_h) <= (a_top + a_h)
    return left_ok and right_ok and top_ok and bottom_ok
##########################################################################
# 切分算法流程图
##########################################################################
def getclusterDict(dotalist, imgheight, imgwidth, pitchSize=1024, max_overlap_rate=0.2):
    """Create one randomly shifted slice window per label and list the labels inside it.

    For every label a ``pitchSize`` x ``pitchSize`` window is placed around
    the label's bounding-box centre, shifted by a random offset and then
    clamped against the image borders. Afterwards every label whose bounding
    box is completely inside a window is added to that window's ``"id"`` list.

    :param dotalist: sequence of labels exposing ``x1..x4`` / ``y1..y4``
    :param imgheight: image height in pixels
    :param imgwidth: image width in pixels
    :param pitchSize: side length of a slice window
    :param max_overlap_rate: accepted for interface compatibility; not used
    :return: dict ``{label_index: {"p": [minX, minY], "id": [label_index, ...]}}``
        where ``p`` is the window's minimum-x / minimum-y corner
    """
    half = pitchSize / 2

    # Bounding boxes are needed by both passes; compute each one only once
    # instead of once per (window, label) pair.
    boxes = [getExternCenter(dotalist, [did]) for did in range(len(dotalist))]

    clusterDict = {}  # clusterDict[i] = {"p": [sx, sy], "id": []}
    for did, (centerX, centerY, minX, minY, maxX, maxY) in enumerate(boxes):
        # Draw both random numbers before use (X first, then Y) so the random
        # sequence is consumed in a fixed order.
        random_numX = random.random() * 2 - 1
        random_numY = random.random() * 2 - 1
        offsetX = _clamped_tile_offset(centerX, maxX - minX, random_numX, half, imgwidth)
        offsetY = _clamped_tile_offset(centerY, maxY - minY, random_numY, half, imgheight)
        tileX = int(centerX + offsetX - half)
        tileY = int(centerY + offsetY - half)
        clusterDict[did] = {"p": [tileX, tileY], "id": [did]}

    for cid in clusterDict:
        tileX, tileY = clusterDict[cid]["p"]
        Abox = [tileX, tileY, pitchSize, pitchSize]
        members = clusterDict[cid]["id"]
        for did, (_cx, _cy, minX, minY, maxX, maxY) in enumerate(boxes):
            Bbox = [minX, minY, maxX - minX, maxY - minY]
            if check_B_in_A(Abox, Bbox) and did not in members:
                members.append(did)
    return clusterDict


def _clamped_tile_offset(center, edge, jitter, half, limit):
    """Return the shift of a window centre along one axis, kept inside ``[0, limit)``.

    :param center: bounding-box centre of the label along this axis
    :param edge: bounding-box extent of the label along this axis
    :param jitter: random factor in ``[-1, 1)``
    :param half: half of the window side length
    :param limit: image size along this axis
    """
    # Random shift only when the label is smaller than half a window.
    offset = int(math.floor((half - edge) * jitter)) if half > edge else 0
    # Pull the window back when it would run past the far image border.
    if not (center + offset + half < limit):
        offset = math.ceil(limit - half - center - 1)
    # Push the window forward when it would start before the near border.
    if center + offset - half < 0:
        offset = math.floor(half - center)
    return offset
def drawSlictplot(clusterDict, dotalist, tiff_data, nrows=10, ncols=9, pitchSize=1024):
    """Show every slice window in its own subplot with its label boxes drawn in red.

    Subplots are filled column by column; the grid must hold at least
    ``len(clusterDict)`` cells.

    :param clusterDict: ``clusterDict[i] = {"p": [sx, sy], "id": [...]}``
    :param dotalist: sequence of labels exposing ``x1..x4`` / ``y1..y4``
    :param tiff_data: 2-D image array the windows are cropped from
    :param nrows: number of subplot rows
    :param ncols: number of subplot columns
    :param pitchSize: side length of a slice window; defaults to 1024, the
        value that was previously hard-coded
    :return: None
    """
    # squeeze=False keeps ``axes`` two-dimensional even when nrows or ncols
    # is 1; otherwise ``axes[rowid, colid]`` fails for single-row/column grids.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 16), squeeze=False)
    plt.tight_layout(pad=3.0)

    for subid, cid in enumerate(clusterDict):
        sx, sy = clusterDict[cid]["p"]
        colid = subid // nrows
        rowid = subid % nrows
        ax = axes[rowid, colid]
        ax.set_title(str(cid))
        sliceData = tiff_data[sy:(sy + pitchSize), sx:(sx + pitchSize)]
        ax.imshow(sliceData, cmap='gray')

        for did in clusterDict[cid]["id"]:
            dota = dotalist[did]
            # Corner coordinates relative to the window origin.
            x_coords = [dota.x1 - sx, dota.x2 - sx, dota.x3 - sx, dota.x4 - sx]
            y_coords = [dota.y1 - sy, dota.y2 - sy, dota.y3 - sy, dota.y4 - sy]
            x_min, x_max = min(x_coords), max(x_coords)
            y_min, y_max = min(y_coords), max(y_coords)
            width = x_max - x_min
            height = y_max - y_min
            ax.add_patch(patches.Rectangle(
                (x_min, y_min), width, height,
                linewidth=2, edgecolor='red', facecolor='none',
            ))
            # Label index at the centre of the rectangle.
            ax.text(x_min + width / 2, y_min + height / 2, str(did),
                    ha='center', va='center', fontsize=6, color='red')

    plt.tight_layout()
    plt.show()
    print("绘图结束")
    return None
def slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath,
                         im_geotrans, im_proj, pitchSize=1024):
    """Crop every slice window and write its raster and its label file.

    For window ``cid`` the outputs are ``<tiff_basename>_<cid>.bin`` (via
    ``write_envi``) and ``<tiff_basename>_<cid>.txt`` (via ``writerDotaFile``),
    with ``cid`` zero-padded to three digits. Label corners are shifted so
    they are relative to the window origin.

    :param clusterDict: ``{cid: {"p": [sx, sy], "id": [label indices]}}``
    :param dotalist: sequence of labels exposing ``x1..x4``, ``y1..y4``,
        ``clsname`` and ``difficulty``
    :param tiff_data: 2-D image array the windows are cropped from
    :param tiff_basename: file-name stem used for the outputs
    :param outfolderpath: folder the outputs are written to
    :param im_geotrans: 6-element geotransform of the full image
    :param im_proj: projection of the full image
    :param pitchSize: side length of a slice window; defaults to 1024, the
        value that was previously hard-coded
    """
    for cid in clusterDict:
        sx, sy = clusterDict[cid]["p"]
        sliceData = tiff_data[sy:(sy + pitchSize), sx:(sx + pitchSize)]

        tag = str(cid).zfill(3)
        outbinpath = os.path.join(outfolderpath, "{}_{}.bin".format(tiff_basename, tag))
        outlabelpath = os.path.join(outfolderpath, "{}_{}.txt".format(tiff_basename, tag))

        # Labels of this window, expressed in window-local pixel coordinates.
        outdotalist = []
        for did in clusterDict[cid]["id"]:
            src = dotalist[did]
            outdotalist.append(createDota(
                src.x1 - sx, src.y1 - sy,
                src.x2 - sx, src.y2 - sy,
                src.x3 - sx, src.y3 - sy,
                src.x4 - sx, src.y4 - sy,
                src.clsname, src.difficulty,
            ))

        # Move the geotransform origin (elements 0 and 3) to the window corner;
        # the remaining four coefficients are unchanged.
        temp_im_geotrans = list(im_geotrans)
        temp_im_geotrans[0] = im_geotrans[0] + sx * im_geotrans[1] + im_geotrans[2] * sy
        temp_im_geotrans[3] = im_geotrans[3] + sx * im_geotrans[4] + im_geotrans[5] * sy

        write_envi(sliceData, temp_im_geotrans, im_proj, outbinpath)
        writerDotaFile(outdotalist, outlabelpath)
##########################################################################
# 处理流程图
##########################################################################
def DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath, methodstr, filterlabels):
    """Run the full pipeline: read raster and labels, plan the windows, write the slices.

    :param inbinfile: path of the input raster
    :param labelfilepath: path of the label file
    :param outfolderpath: folder receiving the slices; created when missing
    :param methodstr: ``"geolabel"`` when label coordinates are geographic,
        anything else when they are already pixel coordinates
    :param filterlabels: label filter forwarded to the label reader
    :raises RuntimeError: when the raster cannot be read
    :raises ValueError: when the raster data is not a single 2-D band
    """
    tiff_proj, tiff_trans, tiff_data = read_tif(inbinfile)
    # read_tif signals failure with a (None, None, None) triple; stop here
    # with a clear message instead of failing inside GDAL further down.
    if tiff_trans is None or tiff_data is None:
        raise RuntimeError("cannot read raster file: {}".format(inbinfile))
    if tiff_data.ndim != 2:
        raise ValueError(
            "expected a single-band 2-D raster, got array of shape {}".format(tiff_data.shape)
        )

    tiff_inv_trans = gdal.InvGeoTransform(tiff_trans)
    dotalist = label2pixelpoints(labelfilepath, tiff_inv_trans, methodstr, filterlabels)

    imgheight, imgwidth = tiff_data.shape
    clusterDict = getclusterDict(dotalist, imgheight, imgwidth)
    # For a visual check, drawSliceRasterPrivew / drawSlictplot can be called
    # here with (tiff_data, dotalist, clusterDict).

    tiff_basename = os.path.splitext(os.path.basename(inbinfile))[0]
    # The writers do not create the folder themselves.
    os.makedirs(outfolderpath, exist_ok=True)
    slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath,
                         tiff_trans, tiff_proj)
def getParams():
    """Parse the command line and return the resulting ``argparse`` namespace.

    Options: ``-i/--inbinfile``, ``-l/--labelfilepath``, ``-o/--outfolderpath``,
    ``-f/--filterlabel`` and the mutually exclusive ``--geolabel`` /
    ``--pixellabel`` flags, which both store into ``method``
    (default ``'geolabel'``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-i', '--inbinfile',
        type=str,
        default=r'F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.bin',
        help='输入tiff的bin文件',
    )
    cli.add_argument(
        '-l', '--labelfilepath',
        type=str,
        default=r"F:\天仪SAR卫星数据集\舰船数据\标注\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01_LC.txt",
        help='输入标注',
    )
    cli.add_argument(
        '-o', '--outfolderpath',
        type=str,
        default=r'F:\天仪SAR卫星数据集\舰船数据\切片结果',
        help='切片文件夹地址',
    )
    cli.add_argument(
        '-f', '--filterlabel',
        type=str,
        default=r'JLC;MLC',
        help='标签过滤',
    )

    # Exactly one coordinate convention may be chosen; both write to ``method``.
    coord_mode = cli.add_mutually_exclusive_group()
    for flag, description in (
        ('geolabel', '标注坐标点为地理坐标'),
        ('pixellabel', '标注坐标系统为输入影像的像空间坐标'),
    ):
        coord_mode.add_argument(
            '--' + flag,
            action='store_const',
            const=flag,
            dest='method',
            help=description,
        )
    cli.set_defaults(method='geolabel')

    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the options, echo them, run the slicing.
    # sys.exit is used instead of the ``exit`` helper that the site module
    # installs for interactive sessions.
    import sys

    try:
        args = getParams()
        inbinfile = args.inbinfile
        labelfilepath = args.labelfilepath
        outfolderpath = args.outfolderpath
        methodstr = args.method
        # Several labels may be given separated by ';'.
        filterlabels = args.filterlabel.strip().split(';')
        print('inbinfile=', inbinfile)
        print('labelfilepath=', labelfilepath)
        print('outfolderpath=', outfolderpath)
        print('methodstr=', methodstr)
        print('filterlabels=', filterlabels)
        DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath, methodstr, filterlabels)
        print("样本切分完成")
    except Exception as e:
        print(e)
        sys.exit(3)
    # NOTE(review): exit status 2 on success (3 on failure) is kept as in the
    # original; presumably the invoking program relies on these values - confirm.
    sys.exit(2)