"""Split DOTA-labelled ship image tiles into civilian (M) and military (J) sets.

Each object's centre is projected to map coordinates and compared with the nearest
civilian-port (MLC) and military-port (JLC) points loaded from shapefiles.
"""
import argparse
import os
import shutil
import sys
from glob import glob
from pathlib import Path

import numpy as np
from osgeo import gdal, ogr
from scipy.spatial import KDTree

from DotaOperator import DotaObj, createDota, readDotaFile, writerDotaFile

# Rough metres per degree, used to turn KD-tree distances (in degrees) into metres.
# NOTE(review): ignores the cos(latitude) shrink of longitude degrees; only a coarse threshold.
METERS_PER_DEGREE = 110000


def _pixel_to_geo(geotrans, col, row):
    """Map pixel (col, row) to georeferenced (x, y) using a GDAL affine geotransform."""
    x = geotrans[0] + geotrans[1] * col + geotrans[2] * row
    y = geotrans[3] + geotrans[4] * col + geotrans[5] * row
    return x, y


def _label_name(tiffname):
    """Return the label file name paired with *tiffname* (``name.tiff`` -> ``name.txt``)."""
    return os.path.splitext(tiffname)[0] + ".txt"


def _nearest_port_distance(tree, lon, lat):
    """Return the smallest distance (degrees) from (lon, lat) to any point in *tree*.

    The longitude is also tried shifted by -360 and +360 degrees so that ports on the
    other side of the antimeridian are found.
    """
    return min(tree.query([lon + shift, lat], k=1)[0] for shift in (0.0, -360.0, 360.0))


def find_tif_files_pathlib(directory, patterns=("*.tiff",)):
    """Recursively collect image files under *directory*.

    Args:
        directory: Root folder to search.
        patterns: Glob patterns matched recursively with ``Path.rglob``. The default only
            matches ``*.tiff``; pass e.g. ``("*.tif", "*.tiff")`` to include ``.tif`` files.

    Returns:
        List of matching file paths as strings.
    """
    root = Path(directory)
    return [str(file) for pattern in patterns for file in root.rglob(pattern)]


def find_srcPath(srcFolder):
    """Map file stem -> path for every ``.tiff`` inside any folder named ``0-原图``.

    Args:
        srcFolder: Root folder searched recursively for ``0-原图`` directories.

    Returns:
        dict from file stem to file path; when stems collide, the file found last wins.
    """
    root_path = Path(srcFolder)
    target_folders = [p for p in root_path.rglob("*") if p.is_dir() and p.name == "0-原图"]
    tiff_files = [f for folder in target_folders for f in find_tif_files_pathlib(folder)]
    return {Path(filepath).stem: filepath for filepath in tiff_files}


def read_tifInfo(path):
    """Read projection, geotransform and extent of a raster.

    Args:
        path: Raster file path.

    Returns:
        (im_proj, im_Geotrans, geoExtend), where geoExtend is [xmin, ymin, xmax, ymax]
        in the raster CRS units; (None, None, None) if GDAL cannot open the file.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        print("无法打开文件,读取文件信息")
        return None, None, None
    cols = dataset.RasterXSize  # image width in pixels
    rows = dataset.RasterYSize  # image height in pixels
    bands = dataset.RasterCount
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    print("行数:", rows)
    print("列数:", cols)
    print("波段:", bands)
    # Project all four corners through the full affine transform so the extent is
    # also correct for rotated rasters (non-zero GT[2] / GT[4]).
    corners = [_pixel_to_geo(im_Geotrans, c, r) for c in (0, cols) for r in (0, rows)]
    xs = [x for x, _ in corners]
    ys = [y for _, y in corners]
    geoExtend = [min(xs), min(ys), max(xs), max(ys)]
    del dataset  # release the GDAL dataset handle
    return im_proj, im_Geotrans, geoExtend


def getshapefileInfo(shp_path):
    """Collect the coordinates of all POINT features in a shapefile.

    The field definitions of every layer are printed for inspection. Features whose
    geometry is missing or is not a POINT are skipped.

    Args:
        shp_path: Path to the ESRI Shapefile.

    Returns:
        numpy array of shape (n, 2) with one [x, y] row per point (shape (0, 2) when the
        file has no points), or None if the shapefile cannot be opened.
    """
    print("shapefile: ", shp_path)
    ogr.RegisterAll()
    driver = ogr.GetDriverByName('ESRI Shapefile')
    datasource = driver.Open(shp_path, 0)  # 0 = open read-only
    if datasource is None:
        print("无法打开Shapefile文件")
        return None
    geom_points = []
    print("layer count: ", datasource.GetLayerCount())
    for layerid in range(datasource.GetLayerCount()):
        print("layer id: ", layerid)
        layer = datasource.GetLayer(layerid)
        layer_defn = layer.GetLayerDefn()
        field_count = layer_defn.GetFieldCount()
        print("field_count:", field_count)
        for i in range(field_count):
            field_defn = layer_defn.GetFieldDefn(i)
            field_type_name = field_defn.GetFieldTypeName(field_defn.GetType())
            print("field_name:", field_defn.GetName(), field_type_name)
        for feature in layer:
            geom = feature.GetGeometryRef()
            # Features without geometry yield None here; skip them instead of crashing.
            if geom is not None and geom.GetGeometryName() == 'POINT':
                geom_points.append([geom.GetX(), geom.GetY()])
    return np.asarray(geom_points, dtype=float).reshape(-1, 2)


def getTiffsInfo(tiffnames, folderpath):
    """Read geometric info for every ``.tiff`` name in *tiffnames*.

    Args:
        tiffnames: File names (not paths); names not ending in ``.tiff`` are ignored.
        folderpath: Folder containing the files.

    Returns:
        dict tiff name -> {"geoExtend", "geoTrans", "imProj"}. Files GDAL cannot open
        are reported and left out, so no labels are produced or copied for them.
    """
    tiffdict = {}
    for tiff_name in tiffnames:
        if not tiff_name.endswith(".tiff"):
            continue
        im_proj, im_Geotrans, geoExtend = read_tifInfo(os.path.join(folderpath, tiff_name))
        if im_Geotrans is None:
            print("skip unreadable tiff:", tiff_name)
            continue
        tiffdict[tiff_name] = {"geoExtend": geoExtend, "geoTrans": im_Geotrans, "imProj": im_proj}
    return tiffdict


def distanceCls(tiffinfos, infolderPath, shipPortTree, jlc_max_distance_m=500):
    """Classify each labelled object as civilian (MLC) or military (JLC) by nearest port.

    For each tiff, its label file (same stem, ``.txt``) in *infolderPath* is read with
    readDotaFile. Each object's centre is projected with the image geotransform, and the
    nearest MLC and JLC port points are looked up, with antimeridian wrap handled. The
    object takes the class of the nearer port. A JLC match farther than
    *jlc_max_distance_m* is treated as invalid and falls back to MLC.

    Args:
        tiffinfos: dict from getTiffsInfo; only each entry's "geoTrans" is used.
        infolderPath: Folder holding the label files.
        shipPortTree: dict with "MLC" and "JLC" KDTree objects built from port points.
        jlc_max_distance_m: Max approximate distance in metres for a JLC match to count.

    Returns:
        dict tiffname -> list of {"obj": label object (its clsname is updated in place),
        "distance": nearest-port distance in approximate metres, "clsname": str}.
        Objects whose processing raises are printed and left out of the list.
    """
    newlabelobjsresult = {}
    for tiffname, info in tiffinfos.items():
        labelpath = os.path.join(infolderPath, _label_name(tiffname))
        geotrans = info["geoTrans"]
        newlabelobjs = []
        for labelobj in readDotaFile(labelpath):
            try:
                # NOTE(review): assumes getCenter() returns (col, row) pixel coordinates, and
                # that the image CRS and the port shapefiles both use lon/lat degrees.
                center = labelobj.getCenter()
                lon, lat = _pixel_to_geo(geotrans, center[0], center[1])
                mlc_distance = _nearest_port_distance(shipPortTree["MLC"], lon, lat)
                jlc_distance = _nearest_port_distance(shipPortTree["JLC"], lon, lat)
                # Ties go to MLC.
                clsname = "MLC" if mlc_distance <= jlc_distance else "JLC"
                distance_m = min(mlc_distance, jlc_distance) * METERS_PER_DEGREE
                # A military match outside the distance threshold is considered invalid.
                if clsname == "JLC" and distance_m > jlc_max_distance_m:
                    clsname = "MLC"
                labelobj.clsname = clsname
                newlabelobjs.append({"obj": labelobj, "distance": distance_m, "clsname": clsname})
            except Exception as e:
                # Best-effort per object: report and continue with the rest of the image.
                print(tiffname, e)
        newlabelobjsresult[tiffname] = newlabelobjs
    return newlabelobjsresult


def updateLabelFile(newOutputFolder, infolderPath, updatelabels):
    """Write updated labels and copy each tiff into ``<out>/M`` or ``<out>/J``.

    An image goes to ``J`` only when it has at least one JLC object and no MLC object;
    otherwise (any MLC object, or no objects at all) it goes to ``M``.

    Args:
        newOutputFolder: Output root; ``J`` and ``M`` subfolders are created if missing.
        infolderPath: Folder containing the source tiffs.
        updatelabels: Result of distanceCls.
    """
    Jfolderpath = os.path.join(newOutputFolder, "J")
    Mfolderpath = os.path.join(newOutputFolder, "M")
    os.makedirs(Jfolderpath, exist_ok=True)
    os.makedirs(Mfolderpath, exist_ok=True)
    for tiffname, newlabelobjs in updatelabels.items():
        labelobjs = [item["obj"] for item in newlabelobjs]
        classes = {item["clsname"] for item in newlabelobjs}
        print(tiffname, "类别: ", classes)
        only_military = "JLC" in classes and "MLC" not in classes
        targetFolderPath = Jfolderpath if only_military else Mfolderpath
        writerDotaFile(labelobjs, os.path.join(targetFolderPath, _label_name(tiffname)))
        shutil.copyfile(os.path.join(infolderPath, tiffname), os.path.join(targetFolderPath, tiffname))
        print("copy : ", tiffname)


def SpliteProcess(srcfolderpath, infolderpath, outfolderpath, MLCPath, JLCPath, JMLCPath):
    """Run the full civilian/military split for one input folder.

    Args:
        srcfolderpath: Original-image root. Currently unused; kept for interface compatibility.
        infolderpath: Folder with ``.tiff`` tiles and their ``.txt`` labels.
        outfolderpath: Output root (receives ``M`` and ``J`` subfolders).
        MLCPath: Shapefile of civilian port points.
        JLCPath: Shapefile of military port points.
        JMLCPath: Shapefile of combined military/civilian ports. Currently unused; that
            class is disabled.

    Returns:
        True on completion.

    Raises:
        IOError: if a port shapefile cannot be opened.
        ValueError: if a port shapefile contains no POINT features.
    """
    shipPortTree = {}
    for clsname, shp_path in (("MLC", MLCPath), ("JLC", JLCPath)):
        points = getshapefileInfo(shp_path)
        if points is None:
            raise IOError("cannot open shapefile: {}".format(shp_path))
        if len(points) == 0:
            raise ValueError("shapefile has no POINT features: {}".format(shp_path))
        shipPortTree[clsname] = KDTree(points)
    tiffinfos = getTiffsInfo(os.listdir(infolderpath), infolderpath)
    updatelabels = distanceCls(tiffinfos, infolderpath, shipPortTree)
    updateLabelFile(outfolderpath, infolderpath, updatelabels)
    return True


def getParams():
    """Parse command-line arguments (defaults point at the original project paths)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--srcfolder', type=str, default=r'R:\TYSAR-德清院\TYSAR-条带模式(SM)\港口', help='原图根目录')
    parser.add_argument('-i', '--intiffolder', type=str, default=r'D:\TYSAR-德清院\切片结果整理', help='输入切片影像及标注文件夹')
    parser.add_argument('-o', '--outfolder', type=str, default=r'D:\TYSAR-德清院\军民区分结果', help='输出文件夹')
    parser.add_argument('-m', '--mLC', type=str, help=r'MLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\港口(民船).shp')
    parser.add_argument('-j', '--jLC', type=str, help=r'JLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军港.shp')
    parser.add_argument('-jm', '--jmLC', type=str, help=r'MJLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军民一体港口.shp')
    return parser.parse_args()


if __name__ == '__main__':
    try:
        args = getParams()
        print('srcfolder=', args.srcfolder)
        print('infolder=', args.intiffolder)
        print('outfolder=', args.outfolder)
        print('mLCPath=', args.mLC)
        print('jLCPath=', args.jLC)
        print('jmLCPath=', args.jmLC)
        SpliteProcess(args.srcfolder, args.intiffolder, args.outfolder, args.mLC, args.jLC, args.jmLC)
        # NOTE(review): exit code 2 = success / 3 = failure is preserved from the original;
        # confirm the calling harness relies on it before switching success to 0.
        sys.exit(2)
    except Exception as e:
        print(e)
        sys.exit(3)