"""Cut a large single-band raster into fixed-size tiles with per-tile DOTA labels.

Pipeline (see ``DataSampleSliceRasterProcess``):

1. read the raster and its georeferencing with GDAL;
2. read the DOTA label file and, when labels are geographic, convert every
   corner to pixel coordinates;
3. choose a regular grid of ``SliceSize`` x ``SliceSize`` tiles whose offset
   maximises the number of label boxes that fit entirely inside one tile;
4. write each non-empty tile as GeoTIFF + PNG preview + DOTA label file.
"""
import argparse
import math
import os
import sys
import traceback
from math import ceil, floor

import matplotlib
import matplotlib.patches as patches
import numpy as np
from matplotlib import pyplot as plt
from osgeo import gdal, ogr
from PIL import Image
from scipy.spatial import cKDTree

from tools.DotaOperator import DotaObj, createDota, readDotaFile, writerDotaFile

##########################################################################
# Parameters
##########################################################################
# Edge length, in pixels, of the square tiles cut from the input raster.
SliceSize = 5000

# Attribute names of the four (x, y) corners carried by a DOTA label object.
_CORNER_ATTRS = (("x1", "y1"), ("x2", "y2"), ("x3", "y3"), ("x4", "y4"))


##########################################################################
# Functions
##########################################################################
def read_tif(path):
    """Read a raster file with GDAL.

    :param path: path of the raster to open.
    :return: ``(projection_wkt, geotransform, data)``; ``(None, None, None)``
        when GDAL cannot open the file (a message is printed in that case).
    """
    dataset = gdal.Open(path)
    if dataset is None:
        print("无法打开文件")
        return None, None, None

    cols = dataset.RasterXSize  # image width
    rows = dataset.RasterYSize  # image height
    bands = dataset.RasterCount
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    print("行数:", rows)
    print("列数:", cols)
    print("波段:", bands)
    dataset = None  # dropping the last reference closes the GDAL dataset
    return im_proj, im_Geotrans, im_data


def Strech_SquareRoot(im_data):
    """Square-root stretch an image to ``uint8`` for preview purposes.

    The square root of the data is scaled linearly so that the chosen
    minimum maps to 1 and the chosen maximum to 255, then clipped to
    ``[0, 255]``.  The full min/max range is used unless it is more than
    three times wider than the 0.1 % - 99.9 % percentile range, in which
    case the percentile range is used so that a few outliers do not
    flatten the contrast.

    Pixels whose square root is not finite (NaN, inf, negative input) are
    written as 0.  A constant image maps every finite pixel to 1.

    :param im_data: numeric numpy array.
    :return: ``uint8`` array with the same shape as ``im_data``.
    """
    rooted = np.sqrt(im_data)
    finite_mask = np.isfinite(rooted)
    valid = rooted[finite_mask]
    if valid.size == 0:
        # Nothing usable to stretch: the original code crashed in nanmin here.
        return np.zeros(np.shape(rooted), dtype=np.uint8)

    minvalue = np.nanmin(valid)
    maxvalue = np.nanmax(valid)
    low = np.percentile(valid, 0.1)
    high = np.percentile(valid, 99.9)
    print('sqrt root min - max ', minvalue,maxvalue)

    robust_span = high - low
    # robust_span == 0 used to divide by zero; fall back to the full range.
    if robust_span > 0 and (maxvalue - minvalue) / robust_span > 3:
        minvalue = low
        maxvalue = high
        print('sqrt root min(0.1) - max(99.9) ', minvalue, maxvalue)

    span = maxvalue - minvalue
    if span > 0:
        scaled = (rooted - minvalue) / span * 254 + 1
        scaled = np.clip(scaled, 0, 255)
    else:
        # Constant image: every finite pixel sits at the lower end (1).
        scaled = np.ones(np.shape(rooted), dtype=np.float64)

    # Casting NaN/inf to uint8 is undefined, so pin those pixels explicitly.
    scaled = np.where(finite_mask, scaled, 0)
    return scaled.astype(np.uint8)


def write_envi(im_data, im_geotrans, im_proj, output_path):
    """Write a 2-D array as a Float32 GeoTIFF plus a PNG preview.

    Despite the historical name, the GDAL ``GTiff`` driver is used, not
    ENVI.  The preview is written next to the raster with the extension
    replaced by ``.png`` and is produced by ``Strech_SquareRoot``.

    :param im_data: 2-D numpy array (rows, cols).
    :param im_geotrans: GDAL geotransform (6 values).
    :param im_proj: projection as a WKT string.
    :param output_path: path of the GeoTIFF to create.
    :raises RuntimeError: if GDAL cannot create ``output_path``.
    """
    im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(
            "GDAL could not create output raster: {}".format(output_path))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    dataset.GetRasterBand(1).WriteArray(im_data)
    dataset.FlushCache()  # make sure the data reaches the disk
    dataset = None  # close the file

    # splitext instead of str.replace(".tiff", ".png"): the latter left
    # ".tif" paths unchanged (the PNG then overwrote the GeoTIFF) and also
    # rewrote directory names that happened to contain ".tiff".
    preview_path = os.path.splitext(output_path)[0] + ".png"
    preview = Strech_SquareRoot(im_data)
    Image.fromarray(preview).save(preview_path, compress_level=0)


def geoXY2pixelXY(geo_x, geo_y, inv_gt):
    """Apply an inverse geotransform to one geographic coordinate.

    :param geo_x: geographic x.
    :param geo_y: geographic y.
    :param inv_gt: inverse geotransform (6 values, GDAL ordering).
    :return: ``(pixel_x, pixel_y)`` as floats (not rounded).
    """
    pixel_x = inv_gt[0] + geo_x * inv_gt[1] + geo_y * inv_gt[2]
    pixel_y = inv_gt[3] + geo_x * inv_gt[4] + geo_y * inv_gt[5]
    return pixel_x, pixel_y


def label2pixelpoints(dotapath, tiff_inv_trans, methodstr, filterlabels):
    """Load DOTA labels and express their corners in pixel coordinates.

    :param dotapath: path of the DOTA label file.
    :param tiff_inv_trans: inverse geotransform of the raster.
    :param methodstr: ``"geolabel"`` converts every corner in place with
        ``geoXY2pixelXY``; any other value leaves the corners untouched.
    :param filterlabels: passed through to ``readDotaFile``
        (presumably the class names to keep - confirm in tools.DotaOperator).
    :return: the list returned by ``readDotaFile`` (mutated in place).
    """
    dotalist = readDotaFile(dotapath, filterlabels)
    if methodstr == "geolabel":
        for dota in dotalist:
            for x_attr, y_attr in _CORNER_ATTRS:
                pixel_x, pixel_y = geoXY2pixelXY(
                    getattr(dota, x_attr), getattr(dota, y_attr), tiff_inv_trans)
                setattr(dota, x_attr, pixel_x)
                setattr(dota, y_attr, pixel_y)

    print("点数:", len(dotalist))
    return dotalist


def _collect_corners(dotalist, ids):
    """Return a ``(4 * len(ids), 2)`` array with the corners of the chosen labels."""
    corners = np.zeros((len(ids) * 4, 2))
    for k, label_index in enumerate(ids):
        dota = dotalist[label_index]
        for c, (x_attr, y_attr) in enumerate(_CORNER_ATTRS):
            corners[k * 4 + c, 0] = getattr(dota, x_attr)
            corners[k * 4 + c, 1] = getattr(dota, y_attr)
    return corners


def getMaxEdge(dotalist, ids):
    """Return the longer side of the axis-aligned box around the chosen labels.

    :param dotalist: list of DOTA label objects.
    :param ids: indices into ``dotalist`` (must not be empty).
    :return: ``max(x extent, y extent)`` of all selected corners.
    """
    corners = _collect_corners(dotalist, ids)
    xedge = np.max(corners[:, 0]) - np.min(corners[:, 0])
    yedge = np.max(corners[:, 1]) - np.min(corners[:, 1])
    return xedge if xedge > yedge else yedge


def getExternCenter(dotalist, ids):
    """Return centre and bounds of the axis-aligned box around the chosen labels.

    :param dotalist: list of DOTA label objects.
    :param ids: indices into ``dotalist`` (must not be empty).
    :return: ``[centerX, centerY, minX, minY, maxX, maxY]``.
    """
    corners = _collect_corners(dotalist, ids)
    minX = np.min(corners[:, 0])
    minY = np.min(corners[:, 1])
    maxX = np.max(corners[:, 0])
    maxY = np.max(corners[:, 1])
    centerX = (minX + maxX) / 2
    centerY = (minY + maxY) / 2
    return [centerX, centerY, minX, minY, maxX, maxY]


def drawSliceRasterPrivew(tiff_data, dotalist, clusterDict):
    """Show the whole raster with label boxes (red) and tile outlines (green).

    Debug helper; blocks in ``plt.show()`` until the window is closed.

    :param tiff_data: 2-D image array.
    :param dotalist: labels in pixel coordinates.
    :param clusterDict: ``{tile_id: {"p": [sx, sy], "id": [...]}}``.
    :return: ``None``.
    """
    fig, ax = plt.subplots(figsize=(20, 16))
    ax.imshow(tiff_data, cmap='gray')

    # One red axis-aligned rectangle per label, annotated with its index.
    for i, dota in enumerate(dotalist):
        x_coords = [dota.x1, dota.x2, dota.x3, dota.x4]
        y_coords = [dota.y1, dota.y2, dota.y3, dota.y4]
        x_min, x_max = min(x_coords), max(x_coords)
        y_min, y_max = min(y_coords), max(y_coords)

        rect = patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax.add_patch(rect)

        center_x = sum(x_coords) / 4
        center_y = sum(y_coords) / 4
        ax.text(center_x, center_y, str(i),
                ha='center', va='center', fontsize=6, color='red')

    # One green square per tile, anchored at the tile's top-left corner.
    for k in clusterDict:
        minX = clusterDict[k]["p"][0]
        minY = clusterDict[k]["p"][1]
        rect = patches.Rectangle(
            (minX, minY), SliceSize, SliceSize,
            linewidth=2, edgecolor='green', facecolor='none'
        )
        ax.add_patch(rect)
        ax.text(minX+512, minY+512, str(k),
                ha='center', va='center', fontsize=6, color='green')

    plt.tight_layout()
    plt.show()

    print("绘图结束")
    return None


def _mark_axis_offsets(covered, box_index, box_min, box_max, extent, patch_size, stride):
    """Record, along one axis, which grid offsets can fully contain a box.

    A tile starting at ``s`` contains ``[box_min, box_max]`` when
    ``box_max - patch_size <= s <= box_min`` and the tile stays inside the
    image (``0 <= s <= extent - patch_size``).  Every valid integer start is
    reduced modulo ``stride`` and ``box_index`` is added to that offset's set.
    """
    lo = max(0, box_max - patch_size)
    hi = min(extent - patch_size, box_min)
    if lo > hi:
        return
    start = ceil(lo)
    end = floor(hi)
    count = end - start + 1
    if count <= 0:
        return
    if count >= stride:
        # The valid starts span a whole period: every offset works.
        offsets = range(stride)
    else:
        offsets = {s % stride for s in range(start, end + 1)}
    for offset in offsets:
        covered[offset].add(box_index)


def _distinct_offsets(covered):
    """Return ``[(first_offset, frozenset)]`` for each distinct non-empty cover set.

    Offsets sharing the same set of boxes give identical scores, so only the
    first (smallest) offset of each distinct set has to be evaluated.
    """
    seen = set()
    distinct = []
    for offset, members in enumerate(covered):
        if not members:
            continue
        key = frozenset(members)
        if key not in seen:
            seen.add(key)
            distinct.append((offset, key))
    return distinct


def find_optimal_slices(H, W, boxes, patch_size=1024, max_overlap_rate=0.2):
    """Choose the tile grid offset that fully contains the most boxes.

    Tiles are laid out on a regular grid with
    ``stride = patch_size - floor(patch_size * max_overlap_rate)`` so that
    neighbouring tiles overlap by at most ``max_overlap_rate`` per axis.
    Among all grid offsets ``(ox, oy)`` in ``[0, stride)`` the one covering
    the most boxes is selected; ties go to the smallest ``ox`` and then the
    smallest ``oy``.  Only tiles lying entirely inside the image are produced.

    :param H: image height in pixels.
    :param W: image width in pixels.
    :param boxes: sequence of ``(x1, y1, x2, y2)`` axis-aligned boxes
        (top-left, bottom-right) in pixel coordinates.
    :param patch_size: edge length of the square tiles.
    :param max_overlap_rate: maximum overlap between neighbouring tiles as
        a fraction of ``patch_size``; must leave a positive stride.
    :return: ``(slices, covered_count)`` where ``slices`` is a list of
        ``(sx, sy)`` tile origins and ``covered_count`` the number of boxes
        fully contained in at least one tile.
    :raises ValueError: if the resulting stride is not positive (this used
        to loop forever).
    """
    overlap_max = patch_size * max_overlap_rate
    stride = patch_size - floor(overlap_max)
    if stride <= 0:
        raise ValueError(
            "max_overlap_rate={} leaves a non-positive stride ({}) for "
            "patch_size={}".format(max_overlap_rate, stride, patch_size))

    x_covered = [set() for _ in range(stride)]
    y_covered = [set() for _ in range(stride)]
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        _mark_axis_offsets(x_covered, i, x1, x2, W, patch_size, stride)
        _mark_axis_offsets(y_covered, i, y1, y2, H, patch_size, stride)

    # A box is covered by offset (ox, oy) when it is covered on both axes.
    # Scanning distinct sets in increasing offset order with a strict '>'
    # reproduces the result of scanning all stride * stride pairs.
    max_covered = 0
    best_ox = 0
    best_oy = 0
    y_candidates = _distinct_offsets(y_covered)
    for ox, x_members in _distinct_offsets(x_covered):
        for oy, y_members in y_candidates:
            current_covered = len(x_members & y_members)
            if current_covered > max_covered:
                max_covered = current_covered
                best_ox = ox
                best_oy = oy

    slices = []
    sx = best_ox
    while sx + patch_size <= W:
        sy = best_oy
        while sy + patch_size <= H:
            slices.append((sx, sy))
            sy += stride
        sx += stride

    return slices, max_covered


def check_B_in_A(A, B):
    """Return True when rectangle ``B`` lies entirely inside rectangle ``A``.

    :param A: ``[x0, y0, w, h]``.
    :param B: ``[x0, y0, w, h]``.
    :return: bool; touching edges count as inside.
    """
    Ax0, Ay0, Aw, Ah = A
    Bx0, By0, Bw, Bh = B

    Ax1 = Ax0 + Aw
    Ay1 = Ay0 + Ah
    Bx1 = Bx0 + Bw
    By1 = By0 + Bh

    return (Bx0 >= Ax0) and (Bx1 <= Ax1) and (By0 >= Ay0) and (By1 <= Ay1)


##########################################################################
# Slicing algorithm
##########################################################################
def getclusterDict(dotalist, imgheight, imgwidth, pitchSize=1024, max_overlap_rate=0.2):
    """Assign every label to the first tile that fully contains it.

    The tile grid comes from ``find_optimal_slices``.  Labels that fit in no
    tile are printed and left out.

    NOTE(review): a label is attached to one tile only, so in overlap zones
    a later tile can show an object without its label - confirm this is the
    intended behaviour for the downstream training data.

    :param dotalist: labels in pixel coordinates.
    :param imgheight: image height in pixels.
    :param imgwidth: image width in pixels.
    :param pitchSize: tile edge length in pixels.
    :param max_overlap_rate: forwarded to ``find_optimal_slices``.
    :return: ``{tile_index: {"p": [sx, sy], "id": [label indices]}}``.
    """
    # Axis-aligned bounding box (x1, y1, x2, y2) of each label, computed once.
    boxs = []
    for idx in range(len(dotalist)):
        corners = _collect_corners(dotalist, [idx])
        x1, y1 = np.min(corners[:, 0]), np.min(corners[:, 1])
        x2, y2 = np.max(corners[:, 0]), np.max(corners[:, 1])
        boxs.append([x1, y1, x2, y2])

    slices, max_covered = find_optimal_slices(
        imgheight, imgwidth, boxs, pitchSize, max_overlap_rate)

    clusterDict = {}
    assigned = set()
    for i, (sx, sy) in enumerate(slices):
        clusterDict[i] = {"p": [sx, sy], "id": []}
        # Use the tile size that was actually used to build the grid; the
        # original tested against the global SliceSize and ignored pitchSize.
        slice_extent = [sx, sy, pitchSize, pitchSize]
        for idx, (x1, y1, x2, y2) in enumerate(boxs):
            if idx in assigned:
                continue
            if check_B_in_A(slice_extent, [x1, y1, x2 - x1, y2 - y1]):
                clusterDict[i]["id"].append(idx)
                assigned.add(idx)

    waitContaindota = []
    for idx in range(len(dotalist)):
        if idx in assigned:
            continue
        waitContaindota.append(idx)
        print("No in slice dota : ",str(dotalist[idx]) )

    print("no process ids ",str(waitContaindota))
    return clusterDict


def drawSlictplot(clusterDict, dotalist, tiff_data, nrows=10, ncols=9):
    """Show every tile in its own subplot together with its label boxes.

    Debug helper; blocks in ``plt.show()``.  Tiles fill the grid column by
    column; ``nrows * ncols`` must be at least ``len(clusterDict)``.

    :param clusterDict: ``{tile_id: {"p": [sx, sy], "id": [...]}}``.
    :param dotalist: labels in pixel coordinates of the full image.
    :param tiff_data: 2-D image array.
    :param nrows: number of subplot rows.
    :param ncols: number of subplot columns.
    :return: ``None``.
    """
    # squeeze=False keeps ``axes`` two-dimensional even for a single row or
    # column, so the [row, col] indexing below always works.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 16), squeeze=False)
    plt.tight_layout(pad=3.0)

    for subid, cid in enumerate(clusterDict):
        sx, sy = clusterDict[cid]["p"]
        colid = subid // nrows
        rowid = subid % nrows
        ax = axes[rowid, colid]
        ax.set_title(str(cid))
        sliceData = tiff_data[sy:(sy+SliceSize), sx:(sx+SliceSize)]
        ax.imshow(sliceData, cmap='gray')

        for did in clusterDict[cid]["id"]:
            dota = dotalist[did]
            # Corner coordinates relative to the tile origin.
            x_coords = [dota.x1-sx, dota.x2-sx, dota.x3-sx, dota.x4-sx]
            y_coords = [dota.y1-sy, dota.y2-sy, dota.y3-sy, dota.y4-sy]
            x_min, x_max = min(x_coords), max(x_coords)
            y_min, y_max = min(y_coords), max(y_coords)
            width = x_max - x_min
            height = y_max - y_min

            rect = patches.Rectangle(
                (x_min, y_min), width, height,
                linewidth=2, edgecolor='red', facecolor='none'
            )
            ax.add_patch(rect)
            ax.text(x_min+width/2, y_min+height/2, str(did),
                    ha='center', va='center', fontsize=6, color='red')

    plt.tight_layout()
    plt.show()

    print("绘图结束")
    return None


def slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath, im_geotrans, im_proj):
    """Write every non-empty tile and its labels to ``outfolderpath``.

    For tile ``cid`` the files ``<tiff_basename>_<cid>.tiff`` (plus the PNG
    preview produced by ``write_envi``) and ``<tiff_basename>_<cid>.txt`` are
    created.  Label corners are shifted to tile-local pixel coordinates and
    the geotransform origin is moved to the tile's top-left corner.

    :param clusterDict: ``{tile_id: {"p": [sx, sy], "id": [...]}}``.
    :param dotalist: labels in pixel coordinates of the full image.
    :param tiff_data: 2-D image array.
    :param tiff_basename: file name stem used for the outputs.
    :param outfolderpath: output directory (created when missing).
    :param im_geotrans: geotransform of the full image.
    :param im_proj: projection WKT of the full image.
    """
    os.makedirs(outfolderpath, exist_ok=True)

    for cid in clusterDict:
        label_ids = clusterDict[cid]["id"]
        if len(label_ids) == 0:
            continue  # tiles without any label are not exported

        sx, sy = clusterDict[cid]["p"]
        sliceData = tiff_data[sy:(sy + SliceSize), sx:(sx + SliceSize)]
        outbinname = "{}_{}.tiff".format(tiff_basename, cid)
        outlabelname = "{}_{}.txt".format(tiff_basename, cid)

        outdotalist = []
        for did in label_ids:
            src = dotalist[did]
            outdotalist.append(createDota(
                src.x1 - sx, src.y1 - sy,
                src.x2 - sx, src.y2 - sy,
                src.x3 - sx, src.y3 - sy,
                src.x4 - sx, src.y4 - sy,
                src.clsname, src.difficulty))

        outlabelpath = os.path.join(outfolderpath, outlabelname)
        outbinpath = os.path.join(outfolderpath, outbinname)

        # Only the origin (elements 0 and 3) changes: it becomes the
        # geographic position of pixel (sx, sy) of the full image.
        temp_im_geotrans = list(im_geotrans)
        temp_im_geotrans[0] = im_geotrans[0] + sx * im_geotrans[1] + im_geotrans[2] * sy
        temp_im_geotrans[3] = im_geotrans[3] + sx * im_geotrans[4] + im_geotrans[5] * sy
        write_envi(sliceData, temp_im_geotrans, im_proj, outbinpath)
        writerDotaFile(outdotalist, outlabelpath)


##########################################################################
# Processing pipeline
##########################################################################
def DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath, methodstr, filterlabels):
    """Run the complete tiling pipeline for one raster / label pair.

    For visual debugging, ``drawSliceRasterPrivew`` and ``drawSlictplot`` can
    be called with the ``clusterDict`` computed here.

    :param inbinfile: input raster path (single band).
    :param labelfilepath: DOTA label file path.
    :param outfolderpath: output directory.
    :param methodstr: ``"geolabel"`` when label corners are geographic
        coordinates, otherwise they are taken as pixel coordinates.
    :param filterlabels: forwarded to ``label2pixelpoints``.
    :raises RuntimeError: if the raster cannot be read.
    :raises ValueError: if the raster is not single band, or if geographic
        labels are requested but the geotransform cannot be inverted.
    """
    tiff_proj, tiff_trans, tiff_data = read_tif(inbinfile)
    if tiff_data is None:
        raise RuntimeError("cannot read input raster: {}".format(inbinfile))
    if tiff_data.ndim != 2:
        raise ValueError(
            "expected a single-band raster, got an array of shape {}".format(
                tiff_data.shape))

    tiff_inv_trans = gdal.InvGeoTransform(tiff_trans)
    if tiff_inv_trans is None and methodstr == "geolabel":
        raise ValueError(
            "geotransform of {} cannot be inverted".format(inbinfile))

    dotalist = label2pixelpoints(labelfilepath, tiff_inv_trans, methodstr, filterlabels)
    imgheight, imgwidth = tiff_data.shape
    clusterDict = getclusterDict(dotalist, imgheight, imgwidth, SliceSize, 0.25)

    tiff_name = os.path.basename(inbinfile)
    tiff_basename = os.path.splitext(tiff_name)[0]
    slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename,
                         outfolderpath, tiff_trans, tiff_proj)


def getParams():
    """Parse the command line.

    :return: ``argparse.Namespace`` with ``inbinfile``, ``labelfilepath``,
        ``outfolderpath``, ``filterlabel`` and ``method`` (``"geolabel"`` by
        default, ``"pixellabel"`` with ``--pixellabel``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i','--inbinfile',type=str,default=r'D:\Annotation_Y\港口\聚束模式\20250505_sp\bc3-sp-org-vv-20250410t053930-020615-000034-005087-01.tiff', help='输入tiff的bin文件')
    parser.add_argument('-l', '--labelfilepath',type=str,default=r"D:\Annotation_Y\港口\聚束模式\20250505_sp\bc3-sp-org-vv-20250410t053930-020615-000034-005087-01_LC.txt", help='输入标注')
    parser.add_argument('-o', '--outfolderpath',type=str,default=r'D:\Annotation_Y\切分', help='切片文件夹地址')
    parser.add_argument('-f', '--filterlabel',type=str,default=r'JLC;MLC', help='标签过滤')
    # --geolabel / --pixellabel both write to ``method`` and exclude each other.
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        '--geolabel',
        action='store_const',
        const='geolabel',
        dest='method',
        help='标注坐标点为地理坐标'
    )
    group.add_argument(
        '--pixellabel',
        action='store_const',
        const='pixellabel',
        dest='method',
        help='标注坐标系统为输入影像的像空间坐标'
    )
    parser.set_defaults(method='geolabel')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    try:
        parser = getParams()
        inbinfile = parser.inbinfile
        labelfilepath = parser.labelfilepath
        outfolderpath = parser.outfolderpath
        methodstr = parser.method
        filterlabels = parser.filterlabel.strip().split(';')
        print('inbinfile=',inbinfile)
        print('labelfilepath=',labelfilepath)
        print('outfolderpath=',outfolderpath)
        print('methodstr=',methodstr)
        print('filterlabels=',filterlabels)
        DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath,methodstr,filterlabels)
        print("样本切分完成")
    except Exception as e:
        print(e)
        traceback.print_exc()  # full details on stderr for diagnosis
        sys.exit(3)
    # Exit status 2 on success / 3 on failure is unusual but kept as is,
    # because whatever launches this script may depend on these codes.
    sys.exit(2)