diff --git a/LabelPortShipRasterSlice/Portshapefile2dota.py b/LabelPortShipRasterSlice/Portshapefile2dota.py new file mode 100644 index 0000000..ceaa3da --- /dev/null +++ b/LabelPortShipRasterSlice/Portshapefile2dota.py @@ -0,0 +1,446 @@ +from osgeo import ogr +import os +import argparse +from osgeo import ogr +import os +import argparse +from osgeo import ogr, gdal +import os +import argparse +import numpy as np +from scipy.spatial import KDTree +from tools.DotaOperator import DotaObj,readDotaFile,writerDotaFile,createDota +from glob import glob +from pathlib import Path +import shutil + + +MLCName="MLC" # M +JLCName="JLC" # J +MJLCName="MJLC" # JM 混合 +NOLCName="NOLC" # 没有港口 + + + +def find_tif_files_pathlib(directory): + path = Path(directory) + # 使用rglob递归匹配所有.tif和.tiff文件 + tif_files = list(path.rglob('*.tiff'))+list(path.rglob('*.tif')) + # 将Path对象转换为字符串路径 + return [str(file) for file in tif_files] + + +def find_srcPath(srcFolder): + root_path = Path(srcFolder) + target_path = [folderpath for folderpath in root_path.rglob("*") if folderpath.is_dir() and folderpath.name=="0-原图"] + tiff_files = [] + for folderpath in target_path: + tiff_files=tiff_files+find_tif_files_pathlib(folderpath) + + tiff_dict={} + for filepath in tiff_files: + rootname=Path(filepath).stem + tiff_dict[rootname]=filepath + + return tiff_dict + + + + + +def read_tifInfo(path): + dataset = gdal.Open(path) # 打开TIF文件 + if dataset is None: + print("无法打开文件,读取文件信息") + return None, None, None + + cols = dataset.RasterXSize # 图像宽度 + rows = dataset.RasterYSize # 图像高度 + bands = dataset.RasterCount + im_proj = dataset.GetProjection() # 获取投影信息 + im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息 + # im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组 + print("行数:", rows) + print("列数:", cols) + print("波段:", bands) + x1=im_Geotrans[0]+im_Geotrans[1]*0 + x2=im_Geotrans[0]+im_Geotrans[1]*cols + + y1=im_Geotrans[3]+im_Geotrans[5]*0 + y2=im_Geotrans[3]+im_Geotrans[5]*rows + + xmin=min(x1,x2) + xmax=max(x1,x2) + ymin=min(y1,y2) + ymax=max(y1,y2) + + geoExtend=[xmin,ymin,xmax,ymax] + del dataset # 关闭数据集 + return im_proj, im_Geotrans,geoExtend + +def getshapefileInfo(shp_path): + """ + 将Shapefile转换为DOTA格式 + :param shp_path: Shapefile文件路径 + """ + + geom_points=[] + + print("shapefile: ",shp_path) + + # 注册所有驱动 + ogr.RegisterAll() + + # 打开Shapefile文件 + driver = ogr.GetDriverByName('ESRI Shapefile') + datasource = driver.Open(shp_path, 0) + if datasource is None: + print("无法打开Shapefile文件") + return + + print("layer count: ",datasource.GetLayerCount()) + for layerid in range(datasource.GetLayerCount()): + print("layer id: ",layerid) + # 获取图层 + layer = datasource.GetLayer(layerid) + layer_defn=layer.GetLayerDefn() + field_count=layer_defn.GetFieldCount() + + print("field_count:", field_count) + for i in range (field_count): + field_defn=layer_defn.GetFieldDefn(i) + field_name=field_defn.GetName() + field_type=field_defn.GetType() + field_type_name=field_defn.GetFieldTypeName(field_type) + print("field_name:", field_name, field_type_name, field_type_name) + + for feature in layer: + geom = feature.GetGeometryRef() + if geom.GetGeometryName() == 'POINT': + x=geom.GetX() + y=geom.GetY() + geom_points.append([x,y]) + return np.array(geom_points) + + +def getTiffsInfo(tiffnames,folderpath): + """ + 获取所有影像的几何信息 + Args: + tiff_paths: tiff列表 + + Returns: + + """ + tiffdict={} + for tiff_name in tiffnames: + if tiff_name.endswith(".tiff"): + tiff_path=os.path.join(folderpath,tiff_name) + im_proj, im_Geotrans, geoExtend=read_tifInfo(tiff_path) + tiffdict[tiff_name]={"geoExtend":geoExtend,"geoTrans":im_Geotrans,"imProj":im_proj} + return tiffdict + + +def getMJSignal(tiffpath,shipPortTree,outfolderPath): + rootname=Path(tiffpath).stem + portTxtpath=os.path.join(outfolderPath,rootname+".txt") + im_proj, im_Geotrans, geoExtend = read_tifInfo(tiffpath) # geoExtend : [xmin,ymin,xmax,ymax] + [xmin, ymin, xmax, ymax]=geoExtend + center_x = (xmin + xmax) / 2.0 + center_y = (ymin + ymax) / 2.0 + center_point = [center_x, center_y] + # 2. 计算能够覆盖整个矩形区域的最小半径(中心点到任一角点的最大距离) + radius_to_corner = np.sqrt((xmax - center_x) ** 2 + (ymax - center_y) ** 2) + + MLCFlag=False + JLCFlag=False + ## MLC + if MLCName in shipPortTree and not shipPortTree[MLCName] is None: + # 3. 使用 query_ball_point 查找以中心点为圆心,radius_to_corner 为半径的圆内的所有点的索引 + potential_indices = shipPortTree[MLCName].query_ball_point(center_point, r=radius_to_corner) + + # 4. 获取这些潜在点的实际坐标 + # 假设你的 KDTree 是从 data_points 构建的:MLCTree = KDTree(data_points) + potential_points = shipPortTree[MLCName].data[potential_indices] # 这是所有潜在点的坐标数组 + + # 5. 进行精确的矩形范围过滤 + # 条件判断:x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间 + x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax) + y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax) + within_rect_indices_mask = x_in_range & y_in_range + + # 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]) + final_points = potential_points[within_rect_indices_mask] + + # final_points 就是你要的矩形范围内的点 + # 如果你需要的是这些点在原始数据中的索引,而不是坐标本身: + final_indices = np.array(potential_indices)[within_rect_indices_mask] + if final_points.shape[0]>0: + MLCFlag=True + with open(portTxtpath,"w",encoding="utf-8") as f: + for i in range(final_points.shape[0]): + f.write("{}\t\t{},{}\n".format("MLC",final_points[i,0],final_points[i,1])) + pass + if JLCName in shipPortTree and not shipPortTree[JLCName] is None: + # 3. 使用 query_ball_point 查找以中心点为圆心,radius_to_corner 为半径的圆内的所有点的索引 + potential_indices = shipPortTree[JLCName].query_ball_point(center_point, r=radius_to_corner) + + # 4. 获取这些潜在点的实际坐标 + # 假设你的 KDTree 是从 data_points 构建的:MLCTree = KDTree(data_points) + potential_points = shipPortTree[JLCName].data[potential_indices] # 这是所有潜在点的坐标数组 + + # 5. 进行精确的矩形范围过滤 + # 条件判断:x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间 + x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax) + y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax) + within_rect_indices_mask = x_in_range & y_in_range + + # 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]) + final_points = potential_points[within_rect_indices_mask] + + # final_points 就是你要的矩形范围内的点 + # 如果你需要的是这些点在原始数据中的索引,而不是坐标本身: + final_indices = np.array(potential_indices)[within_rect_indices_mask] + if final_points.shape[0]>0: + JLCFlag=True + with open(portTxtpath,"a",encoding="utf-8") as f: + for i in range(final_points.shape[0]): + f.write("{}\t\t{},{}\n".format("JLC",final_points[i,0],final_points[i,1])) + pass + # 处理软件 + return MLCFlag,JLCFlag + + +def getTiffInPort(shipPortTree,srcFolderPath_0img,outTiffInfoFilePath,outfolderPath): + tiffpaths=find_tif_files_pathlib(srcFolderPath_0img) + tiffLCPort={ + MLCName:[], + JLCName:[], + MJLCName:[], + NOLCName:[] + } + for tiffpath in tiffpaths: + MLCFlag,JLCFlag=getMJSignal(tiffpath,shipPortTree,outfolderPath) + + if MLCFlag and JLCFlag: + tiffLCPort[MJLCName].append(tiffpath) + elif MLCFlag: + tiffLCPort[MLCName].append(tiffpath) + elif JLCFlag: + tiffLCPort[JLCName].append(tiffpath) + else: + tiffLCPort[NOLCName].append(tiffpath) + + # 输出文件 + with open(outTiffInfoFilePath,'w',encoding="utf-8") as f: + for k in tiffLCPort: + for tiffpath in tiffLCPort[k]: + f.write("{}\t\t{}\n".format(k,tiffpath)) + + + + +def getMJSignal(geoExtend,shipPortTree): + [xmin, ymin, xmax, ymax]=geoExtend + center_x = (xmin + xmax) / 2.0 + center_y = (ymin + ymax) / 2.0 + center_point = [center_x, center_y] + # 2. 计算能够覆盖整个矩形区域的最小半径(中心点到任一角点的最大距离) + radius_to_corner = np.sqrt((xmax - center_x) ** 2 + (ymax - center_y) ** 2) + + MLCFlag=False + JLCFlag=False + ## MLC + if MLCName in shipPortTree and not shipPortTree[MLCName] is None: + # 3. 使用 query_ball_point 查找以中心点为圆心,radius_to_corner 为半径的圆内的所有点的索引 + potential_indices = shipPortTree[MLCName].query_ball_point(center_point, r=radius_to_corner) + + # 4. 获取这些潜在点的实际坐标 + # 假设你的 KDTree 是从 data_points 构建的:MLCTree = KDTree(data_points) + potential_points = shipPortTree[MLCName].data[potential_indices] # 这是所有潜在点的坐标数组 + + # 5. 进行精确的矩形范围过滤 + # 条件判断:x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间 + x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax) + y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax) + within_rect_indices_mask = x_in_range & y_in_range + + # 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]) + final_points = potential_points[within_rect_indices_mask] + + # final_points 就是你要的矩形范围内的点 + # 如果你需要的是这些点在原始数据中的索引,而不是坐标本身: + final_indices = np.array(potential_indices)[within_rect_indices_mask] + if final_points.shape[0]>0: + MLCFlag=True + # with open(portTxtpath,"w",encoding="utf-8") as f: + # for i in range(final_points.shape[0]): + # f.write("{}\t\t{},{}\n".format("MLC",final_points[i,0],final_points[i,1])) + # pass + if JLCName in shipPortTree and not shipPortTree[JLCName] is None: + # 3. 使用 query_ball_point 查找以中心点为圆心,radius_to_corner 为半径的圆内的所有点的索引 + potential_indices = shipPortTree[JLCName].query_ball_point(center_point, r=radius_to_corner) + + # 4. 获取这些潜在点的实际坐标 + # 假设你的 KDTree 是从 data_points 构建的:MLCTree = KDTree(data_points) + potential_points = shipPortTree[JLCName].data[potential_indices] # 这是所有潜在点的坐标数组 + + # 5. 进行精确的矩形范围过滤 + # 条件判断:x 坐标在 xmin 和 xmax 之间,且 y 坐标在 ymin 和 ymax 之间 + x_in_range = (potential_points[:, 0] >= xmin) & (potential_points[:, 0] <= xmax) + y_in_range = (potential_points[:, 1] >= ymin) & (potential_points[:, 1] <= ymax) + within_rect_indices_mask = x_in_range & y_in_range + + # 6. 获取最终在矩形范围内的点的坐标(在原始 data_points 中的索引是 potential_indices[within_rect_indices_mask]) + final_points = potential_points[within_rect_indices_mask] + + # final_points 就是你要的矩形范围内的点 + # 如果你需要的是这些点在原始数据中的索引,而不是坐标本身: + final_indices = np.array(potential_indices)[within_rect_indices_mask] + if final_points.shape[0]>0: + JLCFlag=True + # with open(portTxtpath,"a",encoding="utf-8") as f: + # for i in range(final_points.shape[0]): + # f.write("{}\t\t{},{}\n".format("JLC",final_points[i,0],final_points[i,1])) + # pass + # 处理软件 + if MLCFlag and JLCFlag: + return MJLCName + # tiffLCPort[MJLCName].append(tiffpath) + elif MLCFlag: + return MLCName + # tiffLCPort[MLCName].append(tiffpath) + elif JLCFlag: + return JLCName + # tiffLCPort[JLCName].append(tiffpath) + else: + return NOLCName + # tiffLCPort[NOLCName].append(tiffpath) + # return MLCFlag,JLCFlag + + + +def shapefile_to_dota(shp_path, output_path, shipPortTree,difficulty_value=1): + """ + 将Shapefile转换为DOTA格式 + :param shp_path: Shapefile文件路径 + :param output_path: 输出目录 + :param class_field: 类别字段名 + :param difficulty_value: 难度默认字段 + """ + + # 注册所有驱动 + ogr.RegisterAll() + + # 打开Shapefile文件 + driver = ogr.GetDriverByName('ESRI Shapefile') + datasource = driver.Open(shp_path, 0) + if datasource is None: + print("无法打开Shapefile文件") + return + + # 获取图层 + layer = datasource.GetLayer() + + output_file = output_path + + with open(output_file, 'w',encoding="utf-8") as f: + # 写入DOTA格式头信息(可选) + # f.write('imagesource:unknown\n') + # f.write('gsd:1.0\n') + + # 遍历所有要素 + for feature in layer: + # 获取几何对象 + geom = feature.GetGeometryRef() + if geom is None: + continue + # 获取类别和难度 + try: + class_name = 'unknown' + except Exception as e: + class_name="MLC" + print(e) + difficulty = difficulty_value + + # 处理不同类型的几何图形 + if geom.GetGeometryName() == 'POLYGON': + # 获取多边形外环 + ring = geom.GetGeometryRef(0) + # 获取所有点 + points = [] + for i in range(ring.GetPointCount()): + points.append(ring.GetPoint(i)) + + # 确保有足够的点(至少4个) + if len(points) >= 4: + # 取前4个点作为DOTA格式的四个角点 + # 注意: DOTA要求按顺序排列(顺时针或逆时针) + x1, y1 = points[0][0], points[0][1] + x2, y2 = points[1][0], points[1][1] + x3, y3 = points[2][0], points[2][1] + x4, y4 = points[3][0], points[3][1] + + xmin = min(x1, x2,x3,x4) + xmax = max(x1, x2,x3,x4) + ymin = min(y1, y2,y3,y4) + ymax = max(y1, y2,y3,y4) + # [xmin, ymin, xmax, ymax] = geoExtend + geoExtend = [xmin, ymin, xmax, ymax] + class_name=getMJSignal(geoExtend, shipPortTree) + # 写入DOTA格式行 + line = f"{x1} {y1} {x2} {y2} {x3} {y3} {x4} {y4} {class_name} {difficulty}\n" + f.write(line) + + # 释放资源 + datasource.Destroy() + print("转换完毕") + + +def PortShapeProces(shp_path, output_path, MLCPath,JLCPath,JMLCPath,difficulty_value=1): + shipPort={ + MLCName:getshapefileInfo(MLCPath), + JLCName:getshapefileInfo(JLCPath), + MJLCName:getshapefileInfo(JMLCPath), # 舰船不区分 居民一体 + } + + shipPortTree={ + MLCName:KDTree(shipPort[MLCName]), + JLCName:KDTree(shipPort[JLCName]), + MJLCName:KDTree(shipPort[MJLCName]), + } + shapefile_to_dota(shp_path, output_path,shipPortTree, difficulty) + + + +def getParams(): + parser = argparse.ArgumentParser() + parser.add_argument('-i','--infile',type=str,default=r'D:\Annotation_Y\港口\聚束模式\20250505_sp\bc3-sp-org-vv-20250410t053930-020615-000034-005087-01_LC.shp', help='输入shapefile文件') + parser.add_argument('-o', '--outfile',type=str,default=r'D:\Annotation_Y\港口\聚束模式\20250505_sp\bc3-sp-org-vv-20250410t053930-020615-000034-005087-01_LC.txt', help='输出geojson文件') + parser.add_argument('-m', '--mLC',type=str,help=r'MLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\港口(民船).shp') + parser.add_argument('-j', '--jLC',type=str,help=r'JLC' ,default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军港.shp') + parser.add_argument('-jm', '--jmlc',type=str,help=r'JMLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军民一体港口.shp') + parser.add_argument('-d', '--difficulty',type=int,default=1, help='输出geojson文件') + args = parser.parse_args() + return args + +if __name__ == '__main__': + try: + parser = getParams() + inFilePath=parser.infile + outpath=parser.outfile + mLCPath=parser.mLC + jLCPath=parser.jLC + jmLCPath=parser.jmlc + difficulty=parser.difficulty + print('infile=',inFilePath) + print('outfile=',outpath) + print('mLCPath=',mLCPath) + print('jLCPath=',jLCPath) + print('jmLCPath=',jmLCPath) + print('difficulty=',difficulty) + + exit(2) + except Exception as e: + print(e) + exit(3)