# File: SpacetySliceTools/tools/SpliteShipPort_AA.py
# (listing metadata: 274 lines, 9.4 KiB, Python)

from osgeo import ogr
import os
import argparse
from osgeo import ogr, gdal
import os
import argparse
import numpy as np
from scipy.spatial import KDTree
from DotaOperator import DotaObj,readDotaFile,writerDotaFile,createDota
from glob import glob
from pathlib import Path
import shutil
def find_tif_files_pathlib(directory, patterns=('*.tiff',)):
    """Recursively collect image files below ``directory``.

    By default only ``*.tiff`` files are matched (``*.tif`` is NOT included);
    pass e.g. ``patterns=('*.tif', '*.tiff')`` to widen the search.

    Args:
        directory: root folder to search (str or path-like).
        patterns: glob patterns handed to ``Path.rglob``; the matches of each
            pattern are concatenated in the order the patterns are given.

    Returns:
        list[str]: paths of all matching files, as strings.
    """
    root = Path(directory)
    return [str(file) for pattern in patterns for file in root.rglob(pattern)]
def find_srcPath(srcFolder):
    """Map file stem -> path for every ``.tiff`` inside the ``0-原图`` folders.

    All directories named ``0-原图`` anywhere below ``srcFolder`` are searched
    with ``find_tif_files_pathlib``. If two files share a stem, the one
    encountered last wins.

    Args:
        srcFolder: root folder to scan.

    Returns:
        dict[str, str]: ``{stem: file_path}``.
    """
    origin_dirs = [
        candidate
        for candidate in Path(srcFolder).rglob("*")
        if candidate.is_dir() and candidate.name == "0-原图"
    ]
    collected = []
    for origin_dir in origin_dirs:
        collected.extend(find_tif_files_pathlib(origin_dir))
    return {Path(tif_path).stem: tif_path for tif_path in collected}
def read_tifInfo(path):
    """Read the projection, geotransform and bounding extent of a raster.

    Row/column/band counts are printed as a side effect.

    Args:
        path: raster file to open with GDAL.

    Returns:
        tuple: ``(im_proj, im_Geotrans, geoExtend)``.
        - ``im_proj`` is the string returned by ``GetProjection()``.
        - ``im_Geotrans`` is the 6-element GDAL geotransform.
        - ``geoExtend`` is ``[xmin, ymin, xmax, ymax]`` in the raster's
          georeferenced coordinates.
        Returns ``(None, None, None)`` when GDAL cannot open the file.
    """
    dataset = gdal.Open(path)  # open the raster
    if dataset is None:
        print("无法打开文件,读取文件信息")
        return None, None, None
    cols = dataset.RasterXSize  # image width in pixels
    rows = dataset.RasterYSize  # image height in pixels
    bands = dataset.RasterCount
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    print("行数:", rows)
    print("列数:", cols)
    print("波段:", bands)
    # Project all four image corners with the full affine model:
    #   X = gt[0] + col * gt[1] + row * gt[2]
    #   Y = gt[3] + col * gt[4] + row * gt[5]
    # This gives the correct bounding box for rotated/sheared geotransforms
    # too. For north-up rasters (gt[2] == gt[4] == 0) it reduces to the
    # two-corner computation.
    corners = ((0, 0), (cols, 0), (0, rows), (cols, rows))
    xs = [im_Geotrans[0] + c * im_Geotrans[1] + r * im_Geotrans[2] for c, r in corners]
    ys = [im_Geotrans[3] + c * im_Geotrans[4] + r * im_Geotrans[5] for c, r in corners]
    geoExtend = [min(xs), min(ys), max(xs), max(ys)]
    del dataset  # release the GDAL dataset handle
    return im_proj, im_Geotrans, geoExtend
def getshapefileInfo(shp_path):
    """Read every POINT geometry of a shapefile into an ``(N, 2)`` array.

    The field definitions of each layer are printed for inspection. Features
    whose geometry is missing or is not a POINT are skipped.

    Args:
        shp_path: path to an ESRI Shapefile.

    Returns:
        numpy.ndarray | None: float array of shape ``(N, 2)``, one ``[x, y]``
        row per point (``N`` may be 0). Returns ``None`` when the file cannot
        be opened.
    """
    geom_points = []
    print("shapefile: ", shp_path)
    ogr.RegisterAll()  # register all OGR drivers
    driver = ogr.GetDriverByName('ESRI Shapefile')
    datasource = driver.Open(shp_path, 0)  # 0 -> open read-only
    if datasource is None:
        print("无法打开Shapefile文件")
        return None
    print("layer count: ", datasource.GetLayerCount())
    for layerid in range(datasource.GetLayerCount()):
        print("layer id: ", layerid)
        layer = datasource.GetLayer(layerid)
        layer_defn = layer.GetLayerDefn()
        field_count = layer_defn.GetFieldCount()
        print("field_count:", field_count)
        for i in range(field_count):
            field_defn = layer_defn.GetFieldDefn(i)
            field_type = field_defn.GetType()
            print("field_name:", field_defn.GetName(), field_defn.GetFieldTypeName(field_type))
        for feature in layer:
            geom = feature.GetGeometryRef()
            # A feature may carry no geometry at all; skip it rather than crash.
            if geom is not None and geom.GetGeometryName() == 'POINT':
                geom_points.append([geom.GetX(), geom.GetY()])
    # The reshape keeps the result 2-D even when no point was found. KDTree
    # then receives a well-formed (0, 2) array instead of a 1-D one.
    return np.array(geom_points, dtype=float).reshape(-1, 2)
def getTiffsInfo(tiffnames, folderpath):
    """Collect georeferencing info for every ``.tiff`` file in ``tiffnames``.

    Args:
        tiffnames: file names (not full paths) located in ``folderpath``;
            names not ending with ``.tiff`` are ignored.
        folderpath: directory that contains the files.

    Returns:
        dict: ``{tiff_name: {"geoExtend": ..., "geoTrans": ..., "imProj": ...}}``
        holding the values returned by ``read_tifInfo``. Images that GDAL
        cannot open are reported and left out, so later stages never receive
        a ``None`` geotransform.
    """
    tiffdict = {}
    for tiff_name in tiffnames:
        if not tiff_name.endswith(".tiff"):
            continue
        tiff_path = os.path.join(folderpath, tiff_name)
        im_proj, im_Geotrans, geoExtend = read_tifInfo(tiff_path)
        if im_Geotrans is None:
            # Without a geotransform, no label of this image can be located.
            # Skip the image instead of copying it out with wiped labels.
            print("skip unreadable tiff: ", tiff_path)
            continue
        tiffdict[tiff_name] = {"geoExtend": geoExtend, "geoTrans": im_Geotrans, "imProj": im_proj}
    return tiffdict
def _pixel_to_geo(geotrans, px, py):
    """Map pixel/line coordinates ``(px, py)`` through a GDAL geotransform."""
    x = geotrans[0] + geotrans[1] * px + geotrans[2] * py
    y = geotrans[3] + geotrans[4] * px + geotrans[5] * py
    return x, y


def _nearest_distance(tree, lon, lat):
    """Smallest KD-tree distance from ``(lon, lat)``, also trying ``lon -/+ 360``.

    The +/-360 shifts let an object near the antimeridian match a port that is
    stored on the other side of it.
    """
    return min(tree.query([lon + shift, lat], k=1)[0] for shift in (0.0, -360.0, 360.0))


def distanceCls(tiffinfos, infolderPath, shipPortTree, max_jlc_distance=500, meters_per_degree=110000):
    """Label each DOTA object "MLC" or "JLC" by the class of its nearest port.

    For every image in ``tiffinfos``:
    - Read its label file ``<name>.txt`` from ``infolderPath``.
    - Map each object's centre to map coordinates with the image geotransform.
    - Find the nearest port of each class with the KD-trees.

    An object is labelled "JLC" only when the nearest port is a JLC port and
    it lies within ``max_jlc_distance`` metres. Otherwise it is labelled "MLC".

    Args:
        tiffinfos: ``{tiff_name: {"geoTrans": geotransform, ...}}`` as built by
            ``getTiffsInfo``.
        infolderPath: folder holding the ``.txt`` label files.
        shipPortTree: ``{"MLC": KDTree, "JLC": KDTree}`` of port points.
        max_jlc_distance: JLC acceptance radius in metres.
        meters_per_degree: rough factor turning KD-tree distances (in the port
            coordinates' units, presumably degrees — confirm) into metres.

    Returns:
        dict: ``{tiff_name: [{"obj": labelobj, "distance": metres, "clsname": str}, ...]}``.
        ``labelobj.clsname`` is also updated in place. Objects whose
        processing raises are printed and dropped.
    """
    newlabelobjsresult = {}
    for tiffname, tiffinfo in tiffinfos.items():
        labelpath = os.path.join(infolderPath, tiffname.replace(".tiff", ".txt"))
        geotrans = tiffinfo["geoTrans"]
        newlabelobjs = []
        for labelobj in readDotaFile(labelpath):
            try:
                # NOTE(review): assumes getCenter() returns (column, row) pixel coordinates — confirm
                center = labelobj.getCenter()
                lon, lat = _pixel_to_geo(geotrans, center[0], center[1])
                mlc_distance = _nearest_distance(shipPortTree["MLC"], lon, lat)
                jlc_distance = _nearest_distance(shipPortTree["JLC"], lon, lat)
                # Ties go to MLC.
                if jlc_distance < mlc_distance:
                    clsname, mindistance = "JLC", jlc_distance
                else:
                    clsname, mindistance = "MLC", mlc_distance
                distance_m = mindistance * meters_per_degree
                # A JLC match outside the acceptance radius is considered
                # invalid and falls back to MLC.
                if clsname == "JLC" and distance_m > max_jlc_distance:
                    clsname = "MLC"
                labelobj.clsname = clsname
                newlabelobjs.append({"obj": labelobj, "distance": distance_m, "clsname": clsname})
            except Exception as e:
                print("failed to classify object in", labelpath, ":", e)
        newlabelobjsresult[tiffname] = newlabelobjs
    return newlabelobjsresult
def updateLabelFile(newOutputFolder, infolderPath, updatelabels):
    """Write relabelled DOTA files and copy their images into ``J``/``M`` subfolders.

    An image is routed to ``<newOutputFolder>/J`` only when none of its
    objects is "MLC" and at least one is "JLC". Every other image, including
    those with no objects, goes to ``<newOutputFolder>/M``.

    Args:
        newOutputFolder: output root; its ``J`` and ``M`` subfolders are
            created if missing.
        infolderPath: folder holding the source ``.tiff`` images.
        updatelabels: ``{tiff_name: [{"obj": ..., "clsname": ...}, ...]}`` as
            returned by ``distanceCls``.
    """
    folder_j = os.path.join(newOutputFolder, "J")
    folder_m = os.path.join(newOutputFolder, "M")
    for folder in (folder_j, folder_m):
        if not os.path.exists(folder):
            os.makedirs(folder)
    for tiffname, entries in updatelabels.items():
        label_objs = [entry["obj"] for entry in entries]
        classes = set([entry["clsname"] for entry in entries])
        print(tiffname, "类别: ", classes)
        only_jlc = "MLC" not in classes and "JLC" in classes
        destination = folder_j if only_jlc else folder_m
        writerDotaFile(label_objs, os.path.join(destination, tiffname.replace(".tiff", ".txt")))
        shutil.copyfile(os.path.join(infolderPath, tiffname), os.path.join(destination, tiffname))
        print("copy : ", tiffname)
def SpliteProcess(srcfolderpath, infolderpath, outfolderpath, MLCPath, JLCPath, JMLCPath):
    """Split image slices into J/M output folders by proximity to port points.

    Args:
        srcfolderpath: root of the original images; currently unused (the
            ``find_srcPath`` lookup is disabled).
        infolderpath: folder with the ``.tiff`` slices and their ``.txt`` labels.
        outfolderpath: output root passed to ``updateLabelFile``.
        MLCPath: point shapefile for the "MLC" port class.
        JLCPath: point shapefile for the "JLC" port class.
        JMLCPath: point shapefile for the combined "JMLC" class; currently
            unused, since that class is not distinguished.

    Returns:
        True once processing has finished.

    Raises:
        ValueError: if a port shapefile cannot be opened or yields no points.
    """
    shipPortTree = {}
    for clsname, shp_path in (("MLC", MLCPath), ("JLC", JLCPath)):
        points = getshapefileInfo(shp_path)
        # Fail early with a clear message instead of letting KDTree choke on
        # None or an empty array.
        if points is None or len(points) == 0:
            raise ValueError("no port points loaded from shapefile: {}".format(shp_path))
        shipPortTree[clsname] = KDTree(points)
    tiffinfos = getTiffsInfo(os.listdir(infolderpath), infolderpath)
    updatelabels = distanceCls(tiffinfos, infolderpath, shipPortTree)
    updateLabelFile(outfolderpath, infolderpath, updatelabels)
    return True
def getParams(argv=None):
    """Parse the command-line options of this tool.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` as before; passing a list allows calling the
            parser from code or tests.

    Returns:
        argparse.Namespace with ``srcfolder``, ``intiffolder``, ``outfolder``,
        ``mLC``, ``jLC`` and ``jmLC``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--srcfolder', type=str, default=r'R:\TYSAR-德清院\TYSAR-条带模式(SM)\港口', help='原始影像根目录(当前未使用)')
    parser.add_argument('-i', '--intiffolder', type=str, default=r'D:\TYSAR-德清院\切片结果整理', help='输入切片影像(.tiff)及标签(.txt)文件夹')
    parser.add_argument('-o', '--outfolder', type=str, default=r'D:\TYSAR-德清院\军民区分结果', help='输出文件夹(生成J/M子目录)')
    parser.add_argument('-m', '--mLC', type=str, help=r'MLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\港口(民船).shp')
    parser.add_argument('-j', '--jLC', type=str, help=r'JLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军港.shp')
    parser.add_argument('-jm', '--jmLC', type=str, help=r'MJLC', default=r'D:\TYSAR-德清院\目标点位信息更新\0828目标点位\军民一体港口.shp')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Exit status: 2 after a completed run, 3 if any exception escapes.
    # NOTE(review): 2-for-success is unusual (argparse also exits with 2 on
    # invalid arguments); presumably a contract with the calling job — confirm.
    import sys
    import traceback

    try:
        args = getParams()
        srcfolder = args.srcfolder
        infolder = args.intiffolder
        outfolder = args.outfolder
        mLCPath = args.mLC
        jLCPath = args.jLC
        jmLCPath = args.jmLC
        print('srcfolder=', srcfolder)
        print('infolder=', infolder)
        print('outfolder=', outfolder)
        print('mLCPath=', mLCPath)
        print('jLCPath=', jLCPath)
        print('jmLCPath=', jmLCPath)
        SpliteProcess(srcfolder, infolder, outfolder, mLCPath, jLCPath, jmLCPath)
    except Exception as e:
        print(e)
        traceback.print_exc()
        sys.exit(3)
    sys.exit(2)