import os
import argparse
from osgeo import ogr, gdal
from matplotlib import pyplot as plt
import matplotlib
import matplotlib.patches as patches
from PIL import Image
from scipy.spatial import cKDTree
import numpy as np
from tools.DotaOperator import DotaObj, createDota, readDotaFile, writerDotaFile
import math
from math import ceil, floor

##########################################################################
# Parameters
##########################################################################

# Edge length, in pixels, of the square slices cut from the input raster.
SliceSize = 5000


##########################################################################
# Functions
##########################################################################

def read_tif(path):
    """Read a whole raster file through GDAL.

    :param path: path of the raster file to open
    :return: ``(projection, geotransform, data)``; ``(None, None, None)``
        when GDAL cannot open the file
    """
    dataset = gdal.Open(path)
    if dataset is None:
        print("无法打开文件")
        return None, None, None

    cols = dataset.RasterXSize  # image width
    rows = dataset.RasterYSize  # image height
    bands = dataset.RasterCount
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    im_data = dataset.ReadAsArray(0, 0, cols, rows)

    print("行数:", rows)
    print("列数:", cols)
    print("波段:", bands)

    del dataset  # release the GDAL handle
    return im_proj, im_Geotrans, im_data


def Strech_SquareRoot(im_data):
    """Square-root stretch an array into the uint8 range.

    The square root of the input is mapped linearly so that the minimum
    becomes 1 and the maximum becomes 255.  When the full range is more than
    three times the 0.1 % - 99.9 % percentile range, the percentile limits are
    used instead so that outliers do not squash the histogram.

    NaN pixels become 0.  An input without any finite value yields an
    all-zero image, and a constant input yields 1 for every finite pixel.

    :param im_data: numeric numpy array
    :return: uint8 numpy array with the same shape
    """
    im_data = np.sqrt(im_data)
    valid = im_data[np.isfinite(im_data)]
    if valid.size == 0:
        # Nothing to stretch; the original code raised inside np.nanmin here.
        return np.zeros(im_data.shape, dtype=np.uint8)

    minvalue = np.nanmin(valid)
    maxvalue = np.nanmax(valid)
    minvalue_01Prec = np.percentile(valid, 0.1)
    maxvalue_999Prec = np.percentile(valid, 99.9)
    print('sqrt root min - max ', minvalue, maxvalue)

    # Guard the ratio: a zero percentile range used to divide by zero.
    prec_range = maxvalue_999Prec - minvalue_01Prec
    if prec_range > 0 and (maxvalue - minvalue) / prec_range > 3:
        # Most pixels would otherwise be squeezed into a narrow band.
        minvalue = minvalue_01Prec
        maxvalue = maxvalue_999Prec
        print('sqrt root min(0.1) - max(99.9) ', minvalue, maxvalue)

    value_range = maxvalue - minvalue
    if value_range > 0:
        im_data = (im_data - minvalue) / value_range * 254 + 1
    else:
        # Constant image: every finite pixel maps to the lower bound, 1.
        im_data = im_data - minvalue + 1

    im_data = np.clip(im_data, 0, 255)
    # Casting NaN to uint8 is undefined, so pin it to 0 explicitly.
    im_data[np.isnan(im_data)] = 0
    return im_data.astype(np.uint8)


def write_envi(im_data, im_geotrans, im_proj, output_path):
    """Write a 2-D array as a Float32 GeoTIFF plus a stretched PNG preview.

    Despite its name the function uses the GDAL ``GTiff`` driver.  The PNG is
    written next to the raster, with the raster's extension replaced by
    ``.png``.  If GDAL cannot create the raster, only the PNG is written.

    :param im_data: 2-D numpy array
    :param im_geotrans: affine geotransform (6 values)
    :param im_proj: projection (WKT string)
    :param output_path: path of the raster file to create
    """
    im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Float32)
    if dataset is not None:
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)
        dataset.GetRasterBand(1).WriteArray(im_data)
        dataset.FlushCache()  # make sure the data reaches the disk
        dataset = None  # close the file

    # Swap only the extension.  The former str.replace(".tiff", ".png") also
    # rewrote directory names and, for paths without ".tiff", made the PNG
    # overwrite the raster itself.
    outfilepath = os.path.splitext(output_path)[0] + ".png"
    im_data_uint8 = Strech_SquareRoot(im_data)
    Image.fromarray(im_data_uint8).save(outfilepath, compress_level=0)


def geoXY2pixelXY(geo_x, geo_y, inv_gt):
    """Apply an inverse geotransform to a geographic coordinate.

    :param geo_x: geographic x
    :param geo_y: geographic y
    :param inv_gt: inverse geotransform (6 values)
    :return: ``(pixel_x, pixel_y)``
    """
    pixel_x = inv_gt[0] + geo_x * inv_gt[1] + geo_y * inv_gt[2]
    pixel_y = inv_gt[3] + geo_x * inv_gt[4] + geo_y * inv_gt[5]
    return pixel_x, pixel_y


def label2pixelpoints(dotapath, tiff_inv_trans, methodstr, filterlabels):
    """Load the labels and, for geographic labels, convert them to pixels.

    :param dotapath: path of the DOTA label file
    :param tiff_inv_trans: inverse geotransform of the raster
    :param methodstr: ``"geolabel"`` converts the four corners of every label
        in place; any other value leaves the coordinates untouched
    :param filterlabels: label filter forwarded to ``readDotaFile``
    :return: list of label objects
    """
    dotalist = readDotaFile(dotapath, filterlabels)
    if methodstr == "geolabel":
        for dota in dotalist:
            for corner in ("1", "2", "3", "4"):
                x_name = "x" + corner
                y_name = "y" + corner
                pixel_x, pixel_y = geoXY2pixelXY(
                    getattr(dota, x_name), getattr(dota, y_name), tiff_inv_trans)
                setattr(dota, x_name, pixel_x)
                setattr(dota, y_name, pixel_y)
    print("点数:", len(dotalist))
    return dotalist


def _collect_corners(dotalist, ids):
    """Return a ``(4 * len(ids), 2)`` float array of the selected corners."""
    cornpoint = np.zeros((len(ids) * 4, 2))
    for idx, did in enumerate(ids):
        dota = dotalist[did]
        cornpoint[idx * 4:idx * 4 + 4, 0] = (dota.x1, dota.x2, dota.x3, dota.x4)
        cornpoint[idx * 4:idx * 4 + 4, 1] = (dota.y1, dota.y2, dota.y3, dota.y4)
    return cornpoint


def getMaxEdge(dotalist, ids):
    """Return the longer edge of the bounding box around the selected labels.

    :param dotalist: list of label objects
    :param ids: indices into ``dotalist``
    :return: ``max(width, height)`` of the joint axis-aligned bounding box
    """
    cornpoint = _collect_corners(dotalist, ids)
    xedge = np.max(cornpoint[:, 0]) - np.min(cornpoint[:, 0])
    yedge = np.max(cornpoint[:, 1]) - np.min(cornpoint[:, 1])
    return xedge if xedge > yedge else yedge


def getExternCenter(dotalist, ids):
    """Return centre and limits of the bounding box around the selected labels.

    :param dotalist: list of label objects
    :param ids: indices into ``dotalist``
    :return: ``[centerX, centerY, minX, minY, maxX, maxY]``
    """
    cornpoint = _collect_corners(dotalist, ids)
    minX = np.min(cornpoint[:, 0])
    minY = np.min(cornpoint[:, 1])
    maxX = np.max(cornpoint[:, 0])
    maxY = np.max(cornpoint[:, 1])
    centerX = (minX + maxX) / 2
    centerY = (minY + maxY) / 2
    return [centerX, centerY, minX, minY, maxX, maxY]


def drawSliceRasterPrivew(tiff_data, dotalist, clusterDict):
    """Show the raster with every label box (red) and every slice (green).

    :param tiff_data: 2-D image array
    :param dotalist: list of label objects in pixel coordinates
    :param clusterDict: ``{slice_id: {"p": [sx, sy], "id": [...]}}``
    :return: None
    """
    fig, ax = plt.subplots(figsize=(20, 16))
    ax.imshow(tiff_data, cmap='gray')

    # One unfilled red rectangle per label, annotated with the label index.
    for i, dota in enumerate(dotalist):
        x_coords = [dota.x1, dota.x2, dota.x3, dota.x4]
        y_coords = [dota.y1, dota.y2, dota.y3, dota.y4]

        # Axis-aligned bounding box of the four corners.
        x_min, x_max = min(x_coords), max(x_coords)
        y_min, y_max = min(y_coords), max(y_coords)

        rect = patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax.add_patch(rect)

        center_x = sum(x_coords) / 4
        center_y = sum(y_coords) / 4
        ax.text(center_x, center_y, str(i),
                ha='center', va='center', fontsize=6, color='red')

    # One unfilled green square per slice; "p" is its top-left corner.
    for k in clusterDict:
        minX = clusterDict[k]["p"][0]
        minY = clusterDict[k]["p"][1]
        rect = patches.Rectangle(
            (minX, minY), SliceSize, SliceSize,
            linewidth=2, edgecolor='green', facecolor='none'
        )
        ax.add_patch(rect)
        ax.text(minX + 512, minY + 512, str(k),
                ha='center', va='center', fontsize=6, color='green')

    plt.tight_layout()
    plt.show()
    print("绘图结束")
    return None


def _mark_axis_offsets(covered, box_index, low, high, extent, patch_size, stride):
    """Record, along one axis, the grid offsets whose slices contain a box.

    A slice starting at ``s`` contains the interval ``[low, high]`` when
    ``high - patch_size <= s <= low``; ``s`` is also limited to
    ``[0, extent - patch_size]``.  Every admissible ``s`` is reduced modulo
    ``stride`` and ``box_index`` is added to ``covered[s % stride]``.
    """
    first = max(0, high - patch_size)
    last = min(extent - patch_size, low)
    if first > last:
        return
    start = ceil(first)
    end = floor(last)
    count = end - start + 1
    if count <= 0:
        return
    if count >= stride:
        # The admissible range spans a whole period: every offset works.
        offsets = range(stride)
    else:
        offsets = (s % stride for s in range(start, end + 1))
    for offset in offsets:
        covered[offset].add(box_index)


def find_optimal_slices(H, W, boxes, patch_size=1024, max_overlap_rate=0.2):
    """
    Compute optimal slice positions for the image to maximize the number of
    fully contained rectangular patches (boxes), while ensuring the overlap
    rate between any two slices does not exceed the specified maximum.

    Slices lie on a regular grid with a fixed stride; only the grid offset is
    optimized.  An image smaller than ``patch_size`` in either dimension
    yields no slices.

    Parameters:
    - H: Height of the image.
    - W: Width of the image.
    - boxes: List of tuples or lists, each containing (x1, y1, x2, y2) where
      (x1, y1) is the top-left and (x2, y2) is the bottom-right of a
      rectangular patch.
    - patch_size: Size of each slice (square, e.g., 1024).
    - max_overlap_rate: Maximum allowed overlap rate (e.g., 0.2).

    Returns:
    - slices: List of (sx, sy) starting positions for the slices.
    - covered_count: Number of patches that appear fully in at least one slice.

    Raises:
    - ValueError: if the resulting stride is not positive
      (``max_overlap_rate`` >= 1), which previously caused an endless loop.
    """
    overlap_max = patch_size * max_overlap_rate
    # Ensures overlap <= max_overlap_rate in linear dimensions.
    stride = patch_size - floor(overlap_max)
    if stride <= 0:
        raise ValueError(
            "max_overlap_rate must leave a positive stride (got stride={})".format(stride))

    x_covered = [set() for _ in range(stride)]
    y_covered = [set() for _ in range(stride)]
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        _mark_axis_offsets(x_covered, i, x1, x2, W, patch_size, stride)
        _mark_axis_offsets(y_covered, i, y1, y2, H, patch_size, stride)

    # Find the best offset pair (ox, oy) that maximizes covered patches.
    # Offsets with an empty set can never beat the running maximum, so
    # skipping them gives the same answer as the full stride x stride scan.
    x_candidates = [ox for ox in range(stride) if x_covered[ox]]
    y_candidates = [oy for oy in range(stride) if y_covered[oy]]
    max_covered = 0
    best_ox = 0
    best_oy = 0
    for ox in x_candidates:
        boxes_x = x_covered[ox]
        for oy in y_candidates:
            current_covered = len(boxes_x & y_covered[oy])
            if current_covered > max_covered:
                max_covered = current_covered
                best_ox = ox
                best_oy = oy

    # Generate the slice positions using the best offsets and stride.
    slices = []
    sx = best_ox
    while sx + patch_size <= W:
        sy = best_oy
        while sy + patch_size <= H:
            slices.append((sx, sy))
            sy += stride
        sx += stride

    return slices, max_covered


def check_B_in_A(A, B):
    """Tell whether rectangle A fully contains rectangle B (edges inclusive).

    :param A: [x0, y0, w, h]
    :param B: [x0, y0, w, h]
    :return: True when B lies completely inside A
    """
    Ax0, Ay0, Aw, Ah = A
    Bx0, By0, Bw, Bh = B

    # Right and bottom edges of both rectangles.
    Ax1 = Ax0 + Aw
    Ay1 = Ay0 + Ah
    Bx1 = Bx0 + Bw
    By1 = By0 + Bh

    return (Bx0 >= Ax0) and (Bx1 <= Ax1) and (By0 >= Ay0) and (By1 <= Ay1)


##########################################################################
# Slicing algorithm
##########################################################################

def getclusterDict(dotalist, imgheight, imgwidth, pitchSize=1024, max_overlap_rate=0.2):
    """Assign every label to the slices that fully contain it.

    A label may be listed in several overlapping slices.  Labels that fit in
    no slice are printed at the end.

    :param dotalist: list of label objects in pixel coordinates
    :param imgheight: image height
    :param imgwidth: image width
    :param pitchSize: edge length of a slice
    :param max_overlap_rate: maximum overlap rate between neighbouring slices
    :return: ``{slice_id: {"p": [sx, sy], "id": [label indices]}}``
    """
    boxs = []
    for dota in dotalist:
        xs = np.array([dota.x1, dota.x2, dota.x3, dota.x4])
        ys = np.array([dota.y1, dota.y2, dota.y3, dota.y4])
        boxs.append([np.min(xs), np.min(ys), np.max(xs), np.max(ys)])

    slices, max_covered = find_optimal_slices(
        imgheight, imgwidth, boxs, pitchSize, max_overlap_rate)

    # The extent of a label does not depend on the slice; compute it once.
    dotaExtends = []
    for ids in range(len(dotalist)):
        centerX, centerY, minX, minY, maxX, maxY = getExternCenter(dotalist, [ids])
        dotaExtends.append([minX, minY, maxX - minX, maxY - minY])

    clusterDict = {}
    containedIds = set()
    for i, (sx, sy) in enumerate(slices):
        clusterDict[i] = {"p": [sx, sy], "id": []}
        # Use the size the slices were generated with, not the global.
        slicesExten = [sx, sy, pitchSize, pitchSize]
        for ids, dotaExtend in enumerate(dotaExtends):
            if check_B_in_A(slicesExten, dotaExtend):
                print("True: ", slicesExten, dotaExtend)
                clusterDict[i]["id"].append(ids)
                # Remembered for the report only; the label stays eligible
                # for the remaining slices.
                containedIds.add(ids)
            else:
                print("False: ", slicesExten, dotaExtend)

    waitContaindota = []
    for ids in range(len(dotalist)):
        if ids in containedIds:
            continue
        waitContaindota.append(ids)
        print("No in slice dota : ", str(dotalist[ids]))
    print("no process ids ", str(waitContaindota))
    return clusterDict


def drawSlictplot(clusterDict, dotalist, tiff_data, nrows=10, ncols=9):
    """Show every slice in its own subplot together with its label boxes.

    Subplots are filled column by column; slices that do not fit into the
    ``nrows`` x ``ncols`` grid are not drawn.

    :param clusterDict: clusterDict[i]={"p":[sx,sy],"id":[]}
    :param dotalist: labels (x1, y1, x2, y2, x3, y3, x4, y4, clsname, difficulty)
    :param tiff_data: 2-D image array
    :param nrows: number of subplot rows
    :param ncols: number of subplot columns
    :return: None
    """
    # squeeze=False keeps ``axes`` two-dimensional for every grid shape, so
    # both the row and the column index can be used.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 16), squeeze=False)
    plt.tight_layout(pad=3.0)
    print(tiff_data.shape)

    for subid, cid in enumerate(clusterDict):
        colid, rowid = divmod(subid, nrows)
        if colid >= ncols:
            break  # grid is full
        sx, sy = clusterDict[cid]["p"]
        ax = axes[rowid, colid]
        ax.set_title(str(cid))
        sliceData = tiff_data[sy:(sy + SliceSize), sx:(sx + SliceSize)]
        ax.imshow(sliceData, cmap='gray')

        for did in clusterDict[cid]["id"]:
            dota = dotalist[did]
            # Corner coordinates relative to the slice origin.
            x_coords = [dota.x1 - sx, dota.x2 - sx, dota.x3 - sx, dota.x4 - sx]
            y_coords = [dota.y1 - sy, dota.y2 - sy, dota.y3 - sy, dota.y4 - sy]

            # Axis-aligned bounding box of the four corners.
            x_min, x_max = min(x_coords), max(x_coords)
            y_min, y_max = min(y_coords), max(y_coords)
            width = x_max - x_min
            height = y_max - y_min

            rect = patches.Rectangle(
                (x_min, y_min), width, height,
                linewidth=2, edgecolor='red', facecolor='none'
            )
            ax.add_patch(rect)

            # Label index at the centre of the box.
            ax.text(x_min + width / 2, y_min + height / 2, str(did),
                    ha='center', va='center', fontsize=6, color='red')

    plt.tight_layout()
    plt.show()
    print("绘图结束")
    return None


def slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath, im_geotrans, im_proj):
    """Cut the slices out of the raster and write them with their labels.

    For every slice that holds at least one label this writes
    ``<tiff_basename>_<slice id>.tiff`` (plus the PNG preview produced by
    ``write_envi``) and ``<tiff_basename>_<slice id>.txt`` with the label
    coordinates shifted to the slice origin.

    :param clusterDict: ``{slice_id: {"p": [sx, sy], "id": [...]}}``
    :param dotalist: list of label objects in pixel coordinates
    :param tiff_data: 2-D image array
    :param tiff_basename: base name used for the output files
    :param outfolderpath: output folder
    :param im_geotrans: geotransform of the full raster
    :param im_proj: projection of the full raster
    :return: None
    """
    for cid in clusterDict:
        sx, sy = clusterDict[cid]["p"]
        if len(clusterDict[cid]["id"]) == 0:
            continue  # slices without labels are not exported

        sliceData = tiff_data[sy:(sy + SliceSize), sx:(sx + SliceSize)]
        outbinname = "{}_{}.tiff".format(tiff_basename, cid)
        outlabelname = "{}_{}.txt".format(tiff_basename, cid)

        # Labels of this slice, relative to the slice origin.
        outdotalist = []
        for did in clusterDict[cid]["id"]:
            dota = dotalist[did]
            outdotalist.append(createDota(
                dota.x1 - sx, dota.y1 - sy,
                dota.x2 - sx, dota.y2 - sy,
                dota.x3 - sx, dota.y3 - sy,
                dota.x4 - sx, dota.y4 - sy,
                dota.clsname, dota.difficulty))

        outlabelpath = os.path.join(outfolderpath, outlabelname)
        outbinpath = os.path.join(outfolderpath, outbinname)

        # Move the geotransform origin (elements 0 and 3) to the slice corner.
        temp_im_geotrans = list(im_geotrans)
        temp_im_geotrans[0] = im_geotrans[0] + sx * im_geotrans[1] + im_geotrans[2] * sy  # x
        temp_im_geotrans[3] = im_geotrans[3] + sx * im_geotrans[4] + im_geotrans[5] * sy  # y

        write_envi(sliceData, temp_im_geotrans, im_proj, outbinpath)
        writerDotaFile(outdotalist, outlabelpath)
##########################################################################
# Processing pipeline
##########################################################################

def DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath, methodstr, filterlabels):
    """Slice one raster and its label file into fixed-size samples.

    Reads the raster, loads the labels (converting geographic coordinates to
    pixels when ``methodstr`` is ``"geolabel"``), groups the labels by slice
    and writes every non-empty slice with its labels to ``outfolderpath``.

    :param inbinfile: path of the input raster
    :param labelfilepath: path of the DOTA label file
    :param outfolderpath: output folder; created when missing
    :param methodstr: ``"geolabel"`` or ``"pixellabel"``
    :param filterlabels: label filter forwarded to the label reader
    :raises IOError: when the raster cannot be read
    :raises ValueError: when the raster is not a single-band 2-D image
    """
    tiff_proj, tiff_trans, tiff_data = read_tif(inbinfile)
    if tiff_data is None or tiff_trans is None:
        # read_tif signals failure with (None, None, None); stop here instead
        # of failing later inside GDAL with an unrelated message.
        raise IOError("cannot read raster file: {}".format(inbinfile))
    if tiff_data.ndim != 2:
        raise ValueError(
            "expected a single-band 2-D raster, got shape {}".format(tiff_data.shape))

    tiff_inv_trans = gdal.InvGeoTransform(tiff_trans)
    dotalist = label2pixelpoints(labelfilepath, tiff_inv_trans, methodstr, filterlabels)

    imgheight, imgwidth = tiff_data.shape
    clusterDict = getclusterDict(dotalist, imgheight, imgwidth, SliceSize, 0.25)

    # Optional visual checks:
    # drawSliceRasterPrivew(tiff_data, dotalist, clusterDict)
    # drawSlictplot(clusterDict, dotalist, tiff_data, int(len(clusterDict) / 1 + 1), 1)

    tiff_name = os.path.basename(inbinfile)
    tiff_basename = os.path.splitext(tiff_name)[0]

    # The writers do not create the folder themselves.
    os.makedirs(outfolderpath, exist_ok=True)
    slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename,
                         outfolderpath, tiff_trans, tiff_proj)


def getParams(argv=None):
    """Parse the command-line options.

    :param argv: list of argument strings; ``None`` (default) reads
        ``sys.argv[1:]``
    :return: namespace with ``inbinfile``, ``labelfilepath``,
        ``outfolderpath``, ``filterlabel`` and ``method``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--inbinfile', type=str,
                        default=r'D:\港口\Geo_bc2-sm-org-vv-20231016t135315-008424-0020e8-01.tiff',
                        help='输入tiff的bin文件')
    parser.add_argument('-l', '--labelfilepath', type=str,
                        default=r"D:\港口\港口dota\Geo_bc2-sm-org-vv-20231016t135315-008424-0020e8-01.military_harbor.txt",
                        help='输入标注')
    parser.add_argument('-o', '--outfolderpath', type=str,
                        default=r'D:\港口\切片结果',
                        help='切片文件夹地址')
    parser.add_argument('-f', '--filterlabel', type=str,
                        default=r'mix_airport;civil_harbor;military_harbor;no_harbor',
                        help='标签过滤')

    # The two coordinate conventions are mutually exclusive flags that both
    # store into ``method``.
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        '--geolabel',
        action='store_const',
        const='geolabel',
        dest='method',
        help='标注坐标点为地理坐标'
    )
    group.add_argument(
        '--pixellabel',
        action='store_const',
        const='pixellabel',
        dest='method',
        help='标注坐标系统为输入影像的像空间坐标'
    )
    parser.set_defaults(method='geolabel')

    args = parser.parse_args(argv)
    return args


if __name__ == '__main__':
    import sys  # sys.exit is always available; the bare exit() builtin is not

    try:
        parser = getParams()
        inbinfile = parser.inbinfile
        labelfilepath = parser.labelfilepath
        outfolderpath = parser.outfolderpath
        methodstr = parser.method
        filterlabels = parser.filterlabel.strip().split(';')
        print('inbinfile=', inbinfile)
        print('labelfilepath=', labelfilepath)
        print('outfolderpath=', outfolderpath)
        print('methodstr=', methodstr)
        print('filterlabels=', filterlabels)
        DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath, methodstr, filterlabels)
        print("样本切分完成")
    except Exception as e:
        print(e)
        sys.exit(3)
    # NOTE(review): exit status 2 on success is unusual; it is kept because
    # an external caller may rely on it - confirm.
    sys.exit(2)