"""Stretch SAR GeoTIFF scenes to 8-bit and cut them into slices.

Change log:
    2025.09.16  Slice file names get the suffixes ``_image.png`` / ``_image.tiff``.
    2025.09.22  Added the port slicing requirements.
"""
from osgeo import ogr, gdal
import os
import sys
import argparse
import numpy as np
from PIL import Image
import math
from pathlib import Path

portsliceSize = 5000  # edge length (pixels) of port slices
shipsliceSize = 1024  # edge length (pixels) of ship slices
BlockOverLayer = 0.25  # nominal overlap ratio between neighbouring slices


def existOrCreate(dirpath):
    """Create ``dirpath`` (including missing parents) if it does not exist yet."""
    os.makedirs(dirpath, exist_ok=True)


def get_filename_without_ext(path):
    """Return the base name of ``path`` with its last extension removed.

    Names without a dot, and hidden names such as ``.bashrc``, are returned
    unchanged.
    """
    base_name = os.path.basename(path)
    if '.' not in base_name or base_name.startswith('.'):
        return base_name
    return base_name.rsplit('.', 1)[0]


def read_tif(path):
    """Read a whole raster with GDAL.

    :param path: raster file path.
    :return: ``(projection_wkt, geotransform, data)`` where ``data`` is the
             array returned by ``Dataset.ReadAsArray`` for the full extent;
             ``(None, None, None)`` if GDAL cannot open the file.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        print("无法打开文件")
        return None, None, None
    cols = dataset.RasterXSize  # image width
    rows = dataset.RasterYSize  # image height
    bands = dataset.RasterCount
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    print("行数:", rows)
    print("列数:", cols)
    print("波段:", bands)
    del dataset  # close the dataset
    return im_proj, im_Geotrans, im_data


def _write_single_band_gtiff(im_data, im_geotrans, im_proj, output_path, gdal_type):
    """Write a 2-D array as a single-band GeoTIFF of the given GDAL data type.

    A failure of ``driver.Create`` is reported and the write is skipped.
    """
    im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_path, im_width, im_height, 1, gdal_type)
    if dataset is None:
        print("无法创建文件:", output_path)
        return
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    dataset.GetRasterBand(1).WriteArray(im_data)
    dataset.FlushCache()  # make sure the data reaches the disk
    dataset = None  # dropping the reference closes the file


def write_envi(im_data, im_geotrans, im_proj, output_path):
    """Write a 2-D array as a single-band 8-bit (GDT_Byte) GeoTIFF.

    Despite the name, the GTiff driver is used and ``output_path`` is used
    verbatim (no ``.dat``/``.hdr`` pair is produced).

    :param im_data: 2-D numpy array (values are stored as Byte).
    :param im_geotrans: GDAL affine geotransform (6 numbers).
    :param im_proj: projection as a WKT string.
    :param output_path: output file path.
    """
    _write_single_band_gtiff(im_data, im_geotrans, im_proj, output_path, gdal.GDT_Byte)


def write_tiff(im_data, im_geotrans, im_proj, output_path):
    """Write a 2-D array as a single-band Float32 GeoTIFF.

    :param im_data: 2-D numpy array.
    :param im_geotrans: GDAL affine geotransform (6 numbers).
    :param im_proj: projection as a WKT string.
    :param output_path: output file path (used verbatim).
    """
    _write_single_band_gtiff(im_data, im_geotrans, im_proj, output_path, gdal.GDT_Float32)


def _as_float_array(im_data):
    """Return ``im_data`` as a floating-point array.

    Integer rasters are promoted to float64 so the subtractions in the stretch
    code cannot wrap around; floating-point input is returned without a copy.
    """
    im_data = np.asarray(im_data)
    if np.issubdtype(im_data.dtype, np.floating):
        return im_data
    return im_data.astype(np.float64)


def _scale_to_uint8(im_data, minvalue, maxvalue):
    """Map ``[minvalue, maxvalue]`` linearly onto ``[1, 255]`` and cast to uint8.

    Values below ``minvalue`` end up at 0 after clipping and NaNs become 0.
    A degenerate range (``maxvalue <= minvalue``) uses a unit span instead of
    dividing by zero.
    """
    span = maxvalue - minvalue
    if not span > 0:  # also catches a NaN span
        span = 1.0
    scaled = (im_data - minvalue) / span * 254 + 1
    scaled = np.clip(scaled, 0, 255)
    return np.nan_to_num(scaled, nan=0.0).astype(np.uint8)


def _strech_linear_core(im_data, lower_pct=None, upper_pct=None):
    """Linear stretch shared by the ``Strech_linear*`` functions.

    Only strictly positive, finite pixels (those with a finite dB value) define
    the stretch bounds: the full min/max when ``lower_pct`` is None, otherwise
    the given lower/upper percentiles. Zero and infinite pixels are forced to 0.
    The input array is not modified.
    """
    im_data = _as_float_array(im_data)
    with np.errstate(divide="ignore", invalid="ignore"):
        im_data_dB = 10 * np.log10(im_data)
    immask = np.isfinite(im_data_dB)  # strictly positive, finite pixels
    infmask = np.isinf(im_data_dB)  # zero (-inf dB) and +inf pixels
    imvail_data = im_data[immask]
    if imvail_data.size == 0:
        return np.zeros(im_data.shape, dtype=np.uint8)
    if lower_pct is None:
        minvalue = np.min(imvail_data)
        maxvalue = np.max(imvail_data)
    else:
        minvalue = np.percentile(imvail_data, lower_pct)
        maxvalue = np.percentile(imvail_data, upper_pct)
    # Push zero / infinite pixels well below the stretch floor so they map to 0.
    im_data = np.where(infmask, minvalue - 100, im_data)
    return _scale_to_uint8(im_data, minvalue, maxvalue)


def Strech_linear(im_data):
    """Min/max linear stretch of the positive pixels to uint8 [1, 255]."""
    return _strech_linear_core(im_data)


def Strech_linear1(im_data):
    """1%-99% percentile linear stretch to uint8."""
    return _strech_linear_core(im_data, 1, 99)


def Strech_linear2(im_data):
    """2%-98% percentile linear stretch to uint8."""
    return _strech_linear_core(im_data, 2, 98)


def Strech_linear5(im_data):
    """5%-95% percentile linear stretch to uint8."""
    return _strech_linear_core(im_data, 5, 95)


def Strech_SquareRoot(im_data):
    """Square-root stretch to uint8.

    The square root of the data is stretched linearly over its full range,
    unless that range is more than 3x the 2%-98% percentile range, in which
    case the percentile bounds are used instead (outliers would otherwise
    squeeze most pixels into a few grey levels).
    """
    with np.errstate(invalid="ignore"):
        im_data = np.sqrt(im_data)  # negative inputs become NaN and are ignored
    immask = np.isfinite(im_data)
    imvail_data = im_data[immask]
    if imvail_data.size == 0:
        return np.zeros(np.shape(im_data), dtype=np.uint8)
    minvalue = np.min(imvail_data)
    maxvalue = np.max(imvail_data)
    minvalue_2Prec = np.percentile(imvail_data, 2)
    maxvalue_98Prec = np.percentile(imvail_data, 98)
    print('sqrt root min - max ', minvalue, maxvalue)
    core_span = maxvalue_98Prec - minvalue_2Prec
    # A zero percentile span cannot be stretched; keep the full range then.
    if core_span > 0 and (maxvalue - minvalue) / core_span > 3:
        minvalue = minvalue_2Prec
        maxvalue = maxvalue_98Prec
        print('sqrt root min(2%) - max(98%) ', minvalue, maxvalue)
    return _scale_to_uint8(im_data, minvalue, maxvalue)


def DataStrech(im_data, strechmethod):
    """Stretch ``im_data`` to uint8 with the named method.

    :param strechmethod: one of "Linear", "Linear1", "Linear2", "Linear5",
                         "SquareRoot"; any other value casts the data to uint8
                         without stretching.
    """
    stretchers = {
        "Linear": Strech_linear,
        "Linear1": Strech_linear1,
        "Linear2": Strech_linear2,
        "Linear5": Strech_linear5,
        "SquareRoot": Strech_SquareRoot,
    }
    stretcher = stretchers.get(strechmethod)
    if stretcher is None:
        return im_data.astype(np.uint8)
    return stretcher(im_data)


# File mode
def stretchProcess(infilepath, outfilepath, strechmethod):
    """Stretch one raster and save it as PNG plus an 8-bit GeoTIFF.

    The PNG goes to ``outfilepath``; the 8-bit GeoTIFF goes next to it, named
    after ``outfilepath`` with its extension replaced by ``.bin``.
    """
    im_proj, im_Geotrans, im_data = read_tif(infilepath)
    if im_data is None:
        raise RuntimeError("无法打开文件: {}".format(infilepath))
    envifilepath = os.path.join(os.path.dirname(outfilepath),
                                get_filename_without_ext(outfilepath) + ".bin")
    im_data = DataStrech(im_data, strechmethod).astype(np.uint8)
    write_envi(im_data, im_Geotrans, im_proj, envifilepath)
    Image.fromarray(im_data).save(outfilepath, compress_level=0)
    print("图像拉伸处理结束")


def getsliceGeotrans(GeoTransform, Xpixel, Ypixel):
    """Return the geotransform of a sub-window whose top-left pixel is (Xpixel, Ypixel).

    Only the origin changes; pixel size and rotation terms are copied.
    """
    XGeo = GeoTransform[0] + GeoTransform[1] * Xpixel + GeoTransform[2] * Ypixel
    YGeo = GeoTransform[3] + GeoTransform[4] * Xpixel + GeoTransform[5] * Ypixel
    return [XGeo, GeoTransform[1], GeoTransform[2],
            YGeo, GeoTransform[4], GeoTransform[5]]


def is_all_same(lst, max_diff=400):
    """Return True when fewer than ``max_diff`` pixels differ from the first pixel.

    Used to skip blank / nodata tiles. An empty input counts as "all same".
    """
    arr = np.asarray(lst)
    if arr.size == 0:
        return True
    # flat[0] is the first pixel; arr[0] would be the whole first row of a tile.
    return np.count_nonzero(arr != arr.flat[0]) < max_diff


def getNextSliceNumber(n, sliceSize, overlap=0.25):
    """Compute slice start offsets along one axis of length ``n``.

    :return: ``(newN, ti)`` where ``ti`` are the slice start offsets and
             ``newN`` is the padded axis length needed so that every slice of
             ``sliceSize`` pixels lies fully inside the array.
    """
    step = int(sliceSize * (1 - overlap)) + 1
    ti = list(range(0, n, step))
    if not ti:
        return n, ti
    newN = max(n, ti[-1] + sliceSize)
    # Report the actual overlap (percent) between consecutive slices.
    movelayer = [(start + sliceSize - nxt) / sliceSize * 100.0
                 for start, nxt in zip(ti, ti[1:])]
    print("重叠率:", movelayer)
    return newN, ti


def _pad_for_slicing(im_data, src_im_data, sliceSize):
    """Zero-pad both arrays at the bottom/right so every slice is complete.

    :return: ``(padded_im_data, padded_src_im_data, row_starts, col_starts)``.
    """
    h, w = im_data.shape
    nextH, ht = getNextSliceNumber(h, sliceSize, BlockOverLayer)
    nextW, wt = getNextSliceNumber(w, sliceSize, BlockOverLayer)
    pad_width = ((0, nextH - h), (0, nextW - w))
    im_data = np.pad(im_data, pad_width, mode='constant', constant_values=0)
    src_im_data = np.pad(src_im_data, pad_width, mode='constant', constant_values=0)
    return im_data, src_im_data, ht, wt


def _slice_file_name(rootname, slice_ID, ext):
    """Build ``<rootname>_<4-digit id>_image.<ext>``."""
    return "{}_{}_image.{}".format(rootname, str(slice_ID).zfill(4), ext)


def _write_slice(rootname, slice_ID, im_data_temp, src_im_data_temp, geotrans_temp,
                 im_proj, binfolder, pngfolder, tifffolder):
    """Write one slice as float32 GeoTIFF, 8-bit GeoTIFF and PNG (in that order)."""
    slicesrctiffPath = os.path.join(tifffolder, _slice_file_name(rootname, slice_ID, "tiff"))
    sliceBinPath = os.path.join(binfolder, _slice_file_name(rootname, slice_ID, "tiff"))
    slicepngPath = os.path.join(pngfolder, _slice_file_name(rootname, slice_ID, "png"))
    write_tiff(src_im_data_temp, geotrans_temp, im_proj, slicesrctiffPath)
    write_envi(im_data_temp, geotrans_temp, im_proj, sliceBinPath)
    Image.fromarray(im_data_temp).save(slicepngPath, compress_level=0)


def sliceShipDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder):
    """Cut overlapping ship slices of ``shipsliceSize`` pixels and write them.

    Slice IDs are numbered in row-major order starting at 1; blank tiles
    (see ``is_all_same``) consume an ID but are not written.

    :return: the last slice ID assigned.
    """
    binfolder = os.path.join(outfolder, "舰船", "unit8binfolder")
    pngfolder = os.path.join(outfolder, "舰船", "pngfolder")
    tifffolder = os.path.join(outfolder, "舰船", "tifffolder")
    for folder in (binfolder, pngfolder, tifffolder):
        existOrCreate(folder)
    im_data, src_im_data, ht, wt = _pad_for_slicing(im_data, src_im_data, shipsliceSize)
    slice_ID = 0
    for hi in ht:
        for wi in wt:
            slice_ID += 1
            im_data_temp = im_data[hi:hi + shipsliceSize, wi:wi + shipsliceSize]
            if is_all_same(im_data_temp):
                continue
            src_im_data_temp = src_im_data[hi:hi + shipsliceSize, wi:wi + shipsliceSize]
            geotrans_temp = getsliceGeotrans(im_Geotrans, wi, hi)
            _write_slice(rootname, slice_ID, im_data_temp, src_im_data_temp, geotrans_temp,
                         im_proj, binfolder, pngfolder, tifffolder)
    print("图像切片结束")
    return slice_ID


def ishasPort(im_Geotrans, im_data, MLCPoints, JLCPoints, MJLCPoints):
    """Return True if any annotated port point lies inside the tile's extent.

    The extent is derived from the origin and pixel size only (rotation terms
    of the geotransform are ignored); ``im_data`` is used only for its shape.
    Points are ``[x, y]`` in the same coordinate system as the geotransform.
    """
    rows, cols = im_data.shape[0], im_data.shape[1]
    xmin, xmax = sorted((im_Geotrans[0], im_Geotrans[0] + im_Geotrans[1] * cols))
    ymin, ymax = sorted((im_Geotrans[3], im_Geotrans[3] + im_Geotrans[5] * rows))
    return any(xmin <= p[0] <= xmax and ymin <= p[1] <= ymax
               for p in MLCPoints + JLCPoints + MJLCPoints)


def _read_port_records(portfilestr):
    """Parse a port annotation file into ``(clsname, x, y)`` tuples.

    Each meaningful line looks like ``<class>\\t\\t<x>,<y>``. Lines of three
    characters or fewer are ignored; lines that cannot be parsed are reported
    and skipped.
    """
    records = []
    with open(portfilestr, "r", encoding="utf-8") as portfile:
        for line in portfile:
            if len(line) <= 3:
                continue
            fields = line.rstrip("\r\n").split("\t\t")
            try:
                clsname = fields[0]
                coords = fields[1].split(",")
                pointx = float(coords[0])
                pointy = float(coords[1])
            except (IndexError, ValueError):
                print("跳过无法解析的港口标注行:", line.rstrip("\r\n"))
                continue
            records.append((clsname, pointx, pointy))
    return records


def _group_port_points(records):
    """Split port records into ``(MLCPoints, JLCPoints, MJLCPoints)`` lists of [x, y]."""
    groups = {"MLC": [], "JLC": [], "MJLC": []}
    for clsname, pointx, pointy in records:
        if clsname in groups:
            groups[clsname].append([pointx, pointy])
    return groups["MLC"], groups["JLC"], groups["MJLC"]


def slicePortDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder,
                     slice_ID, portfilestr):
    """Cut port slices of ``portsliceSize`` pixels that contain an annotated point.

    Slice IDs continue from ``slice_ID``. When the port file holds no
    MLC/JLC/MJLC points nothing is written.

    :return: the last slice ID assigned (``slice_ID`` unchanged if nothing was cut).
    """
    MLCPoints, JLCPoints, MJLCPoints = _group_port_points(_read_port_records(portfilestr))
    if not (MLCPoints or JLCPoints or MJLCPoints):
        return slice_ID
    binfolder = os.path.join(outfolder, "港口", "unit8binfolder")
    pngfolder = os.path.join(outfolder, "港口", "pngfolder")
    tifffolder = os.path.join(outfolder, "港口", "tifffolder")
    for folder in (binfolder, pngfolder, tifffolder):
        existOrCreate(folder)
    im_data, src_im_data, ht, wt = _pad_for_slicing(im_data, src_im_data, portsliceSize)
    for hi in ht:
        for wi in wt:
            slice_ID += 1
            geotrans_temp = getsliceGeotrans(im_Geotrans, wi, hi)
            src_im_data_temp = src_im_data[hi:hi + portsliceSize, wi:wi + portsliceSize]
            if not ishasPort(geotrans_temp, src_im_data_temp, MLCPoints, JLCPoints, MJLCPoints):
                continue
            im_data_temp = im_data[hi:hi + portsliceSize, wi:wi + portsliceSize]
            _write_slice(rootname, slice_ID, im_data_temp, src_im_data_temp, geotrans_temp,
                         im_proj, binfolder, pngfolder, tifffolder)
    print("图像切片结束")
    return slice_ID


def sliceLabelPortDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder,
                          slice_ID, portfilestr):
    """Write the whole stretched scene plus a CSV of its port annotations.

    Outputs ``港口/unit8tiff/<rootname>_uint8.tiff`` (8-bit GeoTIFF) and
    ``港口/PortLabels/<rootname>_uint8.csv`` with one ``x,y,class`` row per
    parsed annotation line. Nothing is written when the port file holds no
    MLC/JLC/MJLC points. ``src_im_data`` is accepted for call compatibility
    but not used.

    :return: ``slice_ID`` unchanged (no slices are cut here).
    """
    records = _read_port_records(portfilestr)
    MLCPoints, JLCPoints, MJLCPoints = _group_port_points(records)
    if not (MLCPoints or JLCPoints or MJLCPoints):
        return slice_ID
    portuint8Tifffolder = os.path.join(outfolder, "港口", "unit8tiff")
    # Same folder that stretchSliceProcess prepares for the labels.
    portlabelfolder = os.path.join(outfolder, "港口", "PortLabels")
    existOrCreate(portuint8Tifffolder)
    existOrCreate(portlabelfolder)
    unit8tiffPath = os.path.join(portuint8Tifffolder, "{}_uint8.tiff".format(rootname))
    uint8labelPath = os.path.join(portlabelfolder, "{}_uint8.csv".format(rootname))
    write_envi(im_data, im_Geotrans, im_proj, unit8tiffPath)
    with open(uint8labelPath, "w", encoding="utf-8") as labelfile:
        labelfile.writelines("{},{},{}\n".format(pointx, pointy, clsname)
                             for clsname, pointx, pointy in records)
    return slice_ID


def stretchSliceProcess(infilepath, outfolder, portfilestr, strechmethod):
    """Slice mode: stretch a scene, write an overview PNG, ship slices and port labels."""
    shipbinfolder = os.path.join(outfolder, "舰船", "unit8binfolder")
    shippngfolder = os.path.join(outfolder, "舰船", "pngfolder")
    shiptifffolder = os.path.join(outfolder, "舰船", "tifffolder")
    allpngfolder = os.path.join(outfolder, "allpngfolder")
    portuint8Tifffolder = os.path.join(outfolder, "港口", "unit8tiff")
    portlabelfolder = os.path.join(outfolder, "港口", "PortLabels")
    for folder in (shipbinfolder, shippngfolder, shiptifffolder, allpngfolder,
                   portuint8Tifffolder, portlabelfolder):
        existOrCreate(folder)

    im_proj, im_Geotrans, im_data = read_tif(infilepath)
    if im_data is None:
        raise RuntimeError("无法打开文件: {}".format(infilepath))
    src_im_data = im_data * 1.0  # unstretched floating-point copy for the float32 slices
    im_data = DataStrech(im_data, strechmethod).astype(np.uint8)

    rootname = Path(infilepath).stem
    allImagePath = os.path.join(allpngfolder, rootname + "_all.png")
    Image.fromarray(im_data).save(allImagePath, compress_level=0)

    slice_ID = sliceShipDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder)
    slice_ID = slice_ID + 1
    # Port slicing alternative (currently disabled): slicePortDataset(...)
    slice_ID = sliceLabelPortDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj,
                                     outfolder, slice_ID, portfilestr)
    print("图像切片与拉伸完成")


def getParams():
    """Parse the command-line arguments.

    Returns the namespace with ``infile``, ``portfile``, ``outfile``, ``mode``
    ('filemode' or 'slicemode', default 'slicemode') and ``method`` (stretch
    method name, default 'SquareRoot').
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--infile', type=str,
        default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.tiff",
        help='输入SAR影像文件(tiff)')
    parser.add_argument(
        '-p', '--portfile', type=str,
        default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.txt",
        help='港口标注点文件(txt)')
    parser.add_argument(
        '-o', '--outfile', type=str,
        default=r"F:\天仪SAR卫星数据集\舰船数据\切片结果",
        help='输出文件(文件模式)或输出文件夹(切片模式)')

    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument('--filemode', action='store_const', const='filemode',
                            dest='mode', help='文件模式')
    mode_group.add_argument('--slicemode', action='store_const', const='slicemode',
                            dest='mode', help='切片模式')
    parser.set_defaults(mode='slicemode')

    method_group = parser.add_mutually_exclusive_group()
    # argparse %-formats help text, so a literal percent sign must be written as %%.
    for flag, const, helptext in (
        ('--Linear', 'Linear', '线性拉伸'),
        ('--Linear1prec', 'Linear1', '1%%线性拉伸'),
        ('--Linear2prec', 'Linear2', '2%%线性拉伸'),
        ('--Linear5prec', 'Linear5', '5%%线性拉伸'),
        ('--SquareRoot', 'SquareRoot', '平方根拉伸'),
    ):
        method_group.add_argument(flag, action='store_const', const=const,
                                  dest='method', help=helptext)
    parser.set_defaults(method='SquareRoot')

    return parser.parse_args()


if __name__ == '__main__':
    try:
        args = getParams()
        intiffPath = args.infile
        modestr = args.mode
        methodstr = args.method
        if modestr == "filemode":
            outbinPath = args.outfile
            print('infile=', intiffPath)
            print('outfile=', outbinPath)
            print('method=', methodstr)
            stretchProcess(intiffPath, outbinPath, methodstr)
        elif modestr == "slicemode":
            outfolder = args.outfile
            portfilestr = args.portfile
            print('infile=', intiffPath)
            print('outfolder=', outfolder)
            print('method=', methodstr)
            print('portfile=', portfilestr)
            stretchSliceProcess(intiffPath, outfolder, portfilestr, methodstr)
        else:
            print("模式错误")
            sys.exit(2)
    except Exception as e:
        print(e)
        sys.exit(3)