1. 增加专门针对港口的 shapefile 2 dota的程序

master
陈增辉 2025-09-26 09:52:26 +08:00
parent 83f3d5750b
commit 0b64b18ef0
1 changed files with 446 additions and 0 deletions

View File

@ -0,0 +1,446 @@
from osgeo import ogr
import os
import argparse
from osgeo import ogr
import os
import argparse
from osgeo import ogr, gdal
import os
import argparse
import numpy as np
from scipy.spatial import KDTree
from tools.DotaOperator import DotaObj,readDotaFile,writerDotaFile,createDota
from glob import glob
from pathlib import Path
import shutil
MLCName="MLC" # M
JLCName="JLC" # J
MJLCName="MJLC" # JM 混合
NOLCName="NOLC" # 没有港口
def find_tif_files_pathlib(directory):
path = Path(directory)
# 使用rglob递归匹配所有.tif和.tiff文件
tif_files = list(path.rglob('*.tiff'))+list(path.rglob('*.tif'))
# 将Path对象转换为字符串路径
return [str(file) for file in tif_files]
def find_srcPath(srcFolder):
root_path = Path(srcFolder)
target_path = [folderpath for folderpath in root_path.rglob("*") if folderpath.is_dir() and folderpath.name=="0-原图"]
tiff_files = []
for folderpath in target_path:
tiff_files=tiff_files+find_tif_files_pathlib(folderpath)
tiff_dict={}
for filepath in tiff_files:
rootname=Path(filepath).stem
tiff_dict[rootname]=filepath
return tiff_dict
def read_tifInfo(path):
dataset = gdal.Open(path) # 打开TIF文件
if dataset is None:
print("无法打开文件,读取文件信息")
return None, None, None
cols = dataset.RasterXSize # 图像宽度
rows = dataset.RasterYSize # 图像高度
bands = dataset.RasterCount
im_proj = dataset.GetProjection() # 获取投影信息
im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息
# im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组
print("行数:", rows)
print("列数:", cols)
print("波段:", bands)
x1=im_Geotrans[0]+im_Geotrans[1]*0
x2=im_Geotrans[0]+im_Geotrans[1]*cols
y1=im_Geotrans[3]+im_Geotrans[5]*0
y2=im_Geotrans[3]+im_Geotrans[5]*rows
xmin=min(x1,x2)
xmax=max(x1,x2)
ymin=min(y1,y2)
ymax=max(y1,y2)
geoExtend=[xmin,ymin,xmax,ymax]
del dataset # 关闭数据集
return im_proj, im_Geotrans,geoExtend
def getshapefileInfo(shp_path):
"""
将Shapefile转换为DOTA格式
:param shp_path: Shapefile文件路径
"""
geom_points=[]
print("shapefile: ",shp_path)
# 注册所有驱动
ogr.RegisterAll()
# 打开Shapefile文件
driver = ogr.GetDriverByName('ESRI Shapefile')
datasource = driver.Open(shp_path, 0)
if datasource is None:
print("无法打开Shapefile文件")
return
print("layer count: ",datasource.GetLayerCount())
for layerid in range(datasource.GetLayerCount()):
print("layer id: ",layerid)
# 获取图层
layer = datasource.GetLayer(layerid)
layer_defn=layer.GetLayerDefn()
field_count=layer_defn.GetFieldCount()
print("field_count:", field_count)
for i in range (field_count):
field_defn=layer_defn.GetFieldDefn(i)
field_name=field_defn.GetName()
field_type=field_defn.GetType()
field_type_name=field_defn.GetFieldTypeName(field_type)
print("field_name:", field_name, field_type_name, field_type_name)
for feature in layer:
geom = feature.GetGeometryRef()
if geom.GetGeometryName() == 'POINT':
x=geom.GetX()
y=geom.GetY()
geom_points.append([x,y])
return np.array(geom_points)
def getTiffsInfo(tiffnames,folderpath):
"""
获取所有影像的几何信息
Args:
tiff_paths: tiff列表
Returns:
"""
tiffdict={}
for tiff_name in tiffnames:
if tiff_name.endswith(".tiff"):
tiff_path=os.path.join(folderpath,tiff_name)
im_proj, im_Geotrans, geoExtend=read_tifInfo(tiff_path)
tiffdict[tiff_name]={"geoExtend":geoExtend,"geoTrans":im_Geotrans,"imProj":im_proj}
return tiffdict
def getMJSignal(tiffpath,shipPortTree,outfolderPath):
rootname=Path(tiffpath).stem
portTxtpath=os.path.join(outfolderPath,rootname+".txt")
im_proj, im_Geotrans, geoExtend = read_tifInfo(tiffpath) # geoExtend : [xmin,ymin,xmax,ymax]
[xmin, ymin, xmax, ymax]=geoExtend
center_x = (xmin + xmax) / 2.0
center_y = (ymin + ymax) / 2.0
center_point = [center_x, center_y]
# 2. 计算能够覆盖整个矩形区域的最小半径(中心点到任一角点的最大距离)
radius_to_corner = np.sqrt((xmax - center_x) ** 2 + (ymax - center_y) ** 2)
MLCFlag=False
JLCFlag=False
## MLC
if MLCName in shipPortTree and not shipPortTree[MLCName] is None:
# 3. 使用 query_ball_point 查找以中心点为圆心radius_to_corner 为半径的圆内的所有点的索引
potential_indices = shipPortTree[MLCName].query_ball_point(center_point, r=radius_to_corner)
# 4. 获取这些潜在点的实际坐标
# 假设你的 KDTree 是从 data_points 构建的MLCTree = KDTree(data_points)
potential_points = shipPortTree[MLCName].data[potential_indices] # 这是所有潜在点的坐标数组
# 5. 进行精确的矩形范围过滤
# 条件判断x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间
x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax)
y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax)
within_rect_indices_mask = x_in_range & y_in_range
# 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]
final_points = potential_points[within_rect_indices_mask]
# final_points 就是你要的矩形范围内的点
# 如果你需要的是这些点在原始数据中的索引,而不是坐标本身:
final_indices = np.array(potential_indices)[within_rect_indices_mask]
if final_points.shape[0]>0:
MLCFlag=True
with open(portTxtpath,"w",encoding="utf-8") as f:
for i in range(final_points.shape[0]):
f.write("{}\t\t{},{}\n".format("MLC",final_points[i,0],final_points[i,1]))
pass
if JLCName in shipPortTree and not shipPortTree[JLCName] is None:
# 3. 使用 query_ball_point 查找以中心点为圆心radius_to_corner 为半径的圆内的所有点的索引
potential_indices = shipPortTree[JLCName].query_ball_point(center_point, r=radius_to_corner)
# 4. 获取这些潜在点的实际坐标
# 假设你的 KDTree 是从 data_points 构建的MLCTree = KDTree(data_points)
potential_points = shipPortTree[JLCName].data[potential_indices] # 这是所有潜在点的坐标数组
# 5. 进行精确的矩形范围过滤
# 条件判断x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间
x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax)
y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax)
within_rect_indices_mask = x_in_range & y_in_range
# 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]
final_points = potential_points[within_rect_indices_mask]
# final_points 就是你要的矩形范围内的点
# 如果你需要的是这些点在原始数据中的索引,而不是坐标本身:
final_indices = np.array(potential_indices)[within_rect_indices_mask]
if final_points.shape[0]>0:
JLCFlag=True
with open(portTxtpath,"a",encoding="utf-8") as f:
for i in range(final_points.shape[0]):
f.write("{}\t\t{},{}\n".format("JLC",final_points[i,0],final_points[i,1]))
pass
# 处理软件
return MLCFlag,JLCFlag
def getTiffInPort(shipPortTree,srcFolderPath_0img,outTiffInfoFilePath,outfolderPath):
tiffpaths=find_tif_files_pathlib(srcFolderPath_0img)
tiffLCPort={
MLCName:[],
JLCName:[],
MJLCName:[],
NOLCName:[]
}
for tiffpath in tiffpaths:
MLCFlag,JLCFlag=getMJSignal(tiffpath,shipPortTree,outfolderPath)
if MLCFlag and JLCFlag:
tiffLCPort[MJLCName].append(tiffpath)
elif MLCFlag:
tiffLCPort[MLCName].append(tiffpath)
elif JLCFlag:
tiffLCPort[JLCName].append(tiffpath)
else:
tiffLCPort[NOLCName].append(tiffpath)
# 输出文件
with open(outTiffInfoFilePath,'w',encoding="utf-8") as f:
for k in tiffLCPort:
for tiffpath in tiffLCPort[k]:
f.write("{}\t\t{}\n".format(k,tiffpath))
def getMJSignal(geoExtend,shipPortTree):
[xmin, ymin, xmax, ymax]=geoExtend
center_x = (xmin + xmax) / 2.0
center_y = (ymin + ymax) / 2.0
center_point = [center_x, center_y]
# 2. 计算能够覆盖整个矩形区域的最小半径(中心点到任一角点的最大距离)
radius_to_corner = np.sqrt((xmax - center_x) ** 2 + (ymax - center_y) ** 2)
MLCFlag=False
JLCFlag=False
## MLC
if MLCName in shipPortTree and not shipPortTree[MLCName] is None:
# 3. 使用 query_ball_point 查找以中心点为圆心radius_to_corner 为半径的圆内的所有点的索引
potential_indices = shipPortTree[MLCName].query_ball_point(center_point, r=radius_to_corner)
# 4. 获取这些潜在点的实际坐标
# 假设你的 KDTree 是从 data_points 构建的MLCTree = KDTree(data_points)
potential_points = shipPortTree[MLCName].data[potential_indices] # 这是所有潜在点的坐标数组
# 5. 进行精确的矩形范围过滤
# 条件判断x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间
x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax)
y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax)
within_rect_indices_mask = x_in_range & y_in_range
# 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]
final_points = potential_points[within_rect_indices_mask]
# final_points 就是你要的矩形范围内的点
# 如果你需要的是这些点在原始数据中的索引,而不是坐标本身:
final_indices = np.array(potential_indices)[within_rect_indices_mask]
if final_points.shape[0]>0:
MLCFlag=True
# with open(portTxtpath,"w",encoding="utf-8") as f:
# for i in range(final_points.shape[0]):
# f.write("{}\t\t{},{}\n".format("MLC",final_points[i,0],final_points[i,1]))
# pass
if JLCName in shipPortTree and not shipPortTree[JLCName] is None:
# 3. 使用 query_ball_point 查找以中心点为圆心radius_to_corner 为半径的圆内的所有点的索引
potential_indices = shipPortTree[JLCName].query_ball_point(center_point, r=radius_to_corner)
# 4. 获取这些潜在点的实际坐标
# 假设你的 KDTree 是从 data_points 构建的MLCTree = KDTree(data_points)
potential_points = shipPortTree[JLCName].data[potential_indices] # 这是所有潜在点的坐标数组
# 5. 进行精确的矩形范围过滤
# 条件判断x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间
x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax)
y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax)
within_rect_indices_mask = x_in_range & y_in_range
# 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]
final_points = potential_points[within_rect_indices_mask]
# final_points 就是你要的矩形范围内的点
# 如果你需要的是这些点在原始数据中的索引,而不是坐标本身:
final_indices = np.array(potential_indices)[within_rect_indices_mask]
if final_points.shape[0]>0:
JLCFlag=True
# with open(portTxtpath,"a",encoding="utf-8") as f:
# for i in range(final_points.shape[0]):
# f.write("{}\t\t{},{}\n".format("JLC",final_points[i,0],final_points[i,1]))
# pass
# 处理软件
if MLCFlag and JLCFlag:
return MJLCName
# tiffLCPort[MJLCName].append(tiffpath)
elif MLCFlag:
return MLCName
# tiffLCPort[MLCName].append(tiffpath)
elif JLCFlag:
return JLCName
# tiffLCPort[JLCName].append(tiffpath)
else:
return NOLCName
# tiffLCPort[NOLCName].append(tiffpath)
# return MLCFlag,JLCFlag
def shapefile_to_dota(shp_path, output_path, shipPortTree,difficulty_value=1):
"""
将Shapefile转换为DOTA格式
:param shp_path: Shapefile文件路径
:param output_path: 输出目录
:param class_field: 类别字段名
:param difficulty_value: 难度默认字段
"""
# 注册所有驱动
ogr.RegisterAll()
# 打开Shapefile文件
driver = ogr.GetDriverByName('ESRI Shapefile')
datasource = driver.Open(shp_path, 0)
if datasource is None:
print("无法打开Shapefile文件")
return
# 获取图层
layer = datasource.GetLayer()
output_file = output_path
with open(output_file, 'w',encoding="utf-8") as f:
# 写入DOTA格式头信息(可选)
# f.write('imagesource:unknown\n')
# f.write('gsd:1.0\n')
# 遍历所有要素
for feature in layer:
# 获取几何对象
geom = feature.GetGeometryRef()
if geom is None:
continue
# 获取类别和难度
try:
class_name = 'unknown'
except Exception as e:
class_name="MLC"
print(e)
difficulty = difficulty_value
# 处理不同类型的几何图形
if geom.GetGeometryName() == 'POLYGON':
# 获取多边形外环
ring = geom.GetGeometryRef(0)
# 获取所有点
points = []
for i in range(ring.GetPointCount()):
points.append(ring.GetPoint(i))
# 确保有足够的点(至少4个)
if len(points) >= 4:
# 取前4个点作为DOTA格式的四个角点
# 注意: DOTA要求按顺序排列(顺时针或逆时针)
x1, y1 = points[0][0], points[0][1]
x2, y2 = points[1][0], points[1][1]
x3, y3 = points[2][0], points[2][1]
x4, y4 = points[3][0], points[3][1]
xmin = min(x1, x2,x3,x4)
xmax = max(x1, x2,x3,x4)
ymin = min(y1, y2,y3,y4)
ymax = max(y1, y2,y3,y4)
# [xmin, ymin, xmax, ymax] = geoExtend
geoExtend = [xmin, ymin, xmax, ymax]
class_name=getMJSignal(geoExtend, shipPortTree)
# 写入DOTA格式行
line = f"{x1} {y1} {x2} {y2} {x3} {y3} {x4} {y4} {class_name} {difficulty}\n"
f.write(line)
# 释放资源
datasource.Destroy()
print("转换完毕")
def PortShapeProces(shp_path, output_path, MLCPath,JLCPath,JMLCPath,difficulty_value=1):
shipPort={
MLCName:getshapefileInfo(MLCPath),
JLCName:getshapefileInfo(JLCPath),
MJLCName:getshapefileInfo(JMLCPath), # 舰船不区分 居民一体
}
shipPortTree={
MLCName:KDTree(shipPort[MLCName]),
JLCName:KDTree(shipPort[JLCName]),
MJLCName:KDTree(shipPort[MJLCName]),
}
shapefile_to_dota(shp_path, output_path,shipPortTree, difficulty)
def getParams():
parser = argparse.ArgumentParser()
parser.add_argument('-i','--infile',type=str,default=r'D:\Annotation_Y\港口\聚束模式\20250505_sp\bc3-sp-org-vv-20250410t053930-020615-000034-005087-01_LC.shp', help='输入shapefile文件')
parser.add_argument('-o', '--outfile',type=str,default=r'D:\Annotation_Y\港口\聚束模式\20250505_sp\bc3-sp-org-vv-20250410t053930-020615-000034-005087-01_LC.txt', help='输出geojson文件')
parser.add_argument('-m', '--mLC',type=str,help=r'MLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\港口(民船).shp')
parser.add_argument('-j', '--jLC',type=str,help=r'JLC' ,default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军港.shp')
parser.add_argument('-jm', '--jmlc',type=str,help=r'JMLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军民一体港口.shp')
parser.add_argument('-d', '--difficulty',type=int,default=1, help='输出geojson文件')
args = parser.parse_args()
return args
if __name__ == '__main__':
try:
parser = getParams()
inFilePath=parser.infile
outpath=parser.outfile
mLCPath=parser.mLC
jLCPath=parser.jLC
jmLCPath=parser.jmlc
difficulty=parser.difficulty
print('infile=',inFilePath)
print('outfile=',outpath)
print('mLCPath=',mLCPath)
print('jLCPath=',jLCPath)
print('jmLCPath=',jmLCPath)
print('difficulty=',difficulty)
exit(2)
except Exception as e:
print(e)
exit(3)