SpacetySliceTools/generatorRasterSlicesTools/SpacetyTIFFDataStretch2PNG_...

587 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
2025.09.16 切片增加后缀 _image.png _image.tiff
2025.09.22 增加港口切片要求
"""
from osgeo import ogr, gdal
import os
import argparse
import numpy as np
from PIL import Image
import math
from pathlib import Path
# Edge length (pixels) of the square window cut around each port point.
portsliceSize = 5000
# Edge length (pixels) of the regular ship-detection tiles.
shipsliceSize = 1024
# Fractional overlap between neighbouring ship tiles (passed to getNextSliceNumber).
BlockOverLayer = 0.25
def existOrCreate(dirpath):
    """Ensure the directory ``dirpath`` exists, creating missing parents.

    ``exist_ok=True`` replaces the former ``os.path.exists`` check followed by
    ``os.makedirs``: that pair could fail if the directory was created between
    the check and the call. If ``dirpath`` exists as a regular file,
    ``os.makedirs`` raises ``FileExistsError`` instead of silently continuing.
    """
    os.makedirs(dirpath, exist_ok=True)
def get_filename_without_ext(path):
    """Return the last path component with its final extension stripped.

    Names that contain no dot, or that start with a dot (hidden files such as
    ``.bashrc``), are returned unchanged. Only the last extension is removed:
    ``a.b.c`` -> ``a.b``.
    """
    name = os.path.basename(path)
    stem, dot, _extension = name.rpartition('.')
    if not dot or name.startswith('.'):
        return name
    return stem
def read_tif(path):
    """Read a whole raster with GDAL.

    :param path: path of any GDAL-readable raster
    :return: ``(projection_wkt, geotransform, array)``. The array comes from
        ``ReadAsArray`` over the full extent (2-D for a single band; GDAL
        returns ``(bands, rows, cols)`` for multi-band rasters).
        If the file cannot be opened, a message is printed and
        ``(None, None, None)`` is returned.
    """
    src = gdal.Open(path)
    if src is None:
        print("无法打开文件")
        return None, None, None
    width, height = src.RasterXSize, src.RasterYSize
    band_count = src.RasterCount
    projection = src.GetProjection()
    geotransform = src.GetGeoTransform()
    pixels = src.ReadAsArray(0, 0, width, height)
    # Report rows, columns and band count (labels are user-facing output).
    for label, value in (("行数:", height), ("列数:", width), ("波段:", band_count)):
        print(label, value)
    src = None  # drop the last reference so GDAL closes the dataset
    return projection, geotransform, pixels
def write_envi(im_data, im_geotrans, im_proj, output_path):
    """Write a 2-D array as a single-band 8-bit (``GDT_Byte``) GeoTIFF.

    Despite its name this function uses GDAL's ``GTiff`` driver, not ENVI, and
    writes exactly ``output_path`` (no ``.dat``/``.hdr`` pair is produced).
    Callers in this module pass data that has already been converted to uint8.

    :param im_data: 2-D numpy array (rows, cols); multi-band input is not supported
    :param im_geotrans: GDAL affine geotransform (6 coefficients)
    :param im_proj: projection as a WKT string
    :param output_path: destination path, used verbatim
    """
    im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte)
    if dataset is None:
        # Previously ignored silently; report it so missing outputs are noticed.
        print("无法创建文件:", output_path)
        return
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    dataset.GetRasterBand(1).WriteArray(im_data)
    dataset.FlushCache()  # push pending writes to disk
    dataset = None  # releasing the reference closes the file
def write_tiff(im_data, im_geotrans, im_proj, output_path):
    """Write a 2-D array as a single-band 32-bit float (``GDT_Float32``) GeoTIFF.

    Used for the tiles that keep the original (un-stretched) pixel values.
    Writes exactly ``output_path`` with GDAL's ``GTiff`` driver.

    :param im_data: 2-D numpy array (rows, cols); multi-band input is not supported
    :param im_geotrans: GDAL affine geotransform (6 coefficients)
    :param im_proj: projection as a WKT string
    :param output_path: destination path, used verbatim
    """
    im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Float32)
    if dataset is None:
        # Previously ignored silently; report it so missing outputs are noticed.
        print("无法创建文件:", output_path)
        return
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    dataset.GetRasterBand(1).WriteArray(im_data)
    dataset.FlushCache()  # push pending writes to disk
    dataset = None  # releasing the reference closes the file
def _strech_linear_core(im_data, low_percent=None, high_percent=None):
    """Shared implementation of the linear stretches.

    A pixel is "valid" when it is finite and strictly positive, which is
    exactly when its dB value ``10*log10(v)`` is finite. The stretch window
    ``[lo, hi]`` is the min/max of the valid pixels, or the requested
    percentiles of them. Valid pixels are mapped linearly with ``lo -> 1`` and
    ``hi -> 255``; anything outside the window is clipped to 0..255. Invalid
    pixels (zero, negative, NaN, +/-inf) map to 0. The input is never modified.

    :param im_data: numpy array of any shape and numeric dtype
    :param low_percent: percentile giving ``lo`` (None -> minimum)
    :param high_percent: percentile giving ``hi`` (None -> maximum)
    :return: uint8 array with the same shape as ``im_data``
    """
    data = np.asarray(im_data)
    if not np.issubdtype(data.dtype, np.floating):
        # Integer inputs would wrap around on subtraction (e.g. uint16 0 - 1).
        data = data.astype(np.float64)
    with np.errstate(invalid="ignore"):
        valid = np.isfinite(data) & (data > 0)
    values = data[valid]
    if values.size == 0:
        # Nothing to stretch against; the old code raised here.
        return np.zeros(data.shape, dtype=np.uint8)
    lo = values.min() if low_percent is None else np.percentile(values, low_percent)
    hi = values.max() if high_percent is None else np.percentile(values, high_percent)
    with np.errstate(invalid="ignore", over="ignore"):
        if hi > lo:
            scaled = (data - lo) / (hi - lo) * 254 + 1
        else:
            # Degenerate window (e.g. constant image): avoid 0/0 and use a step.
            scaled = np.where(data > hi, 255.0, np.where(data < lo, 0.0, 1.0))
    scaled[~valid] = 0
    return np.clip(scaled, 0, 255).astype(np.uint8)


def Strech_linear(im_data):
    """Min/max linear stretch of the valid (positive, finite) pixels to 1..255; invalid -> 0."""
    return _strech_linear_core(im_data)


def Strech_linear1(im_data):
    """1%-99% percentile linear stretch of the valid pixels to 1..255; invalid -> 0."""
    return _strech_linear_core(im_data, 1, 99)


def Strech_linear2(im_data):
    """2%-98% percentile linear stretch of the valid pixels to 1..255; invalid -> 0."""
    return _strech_linear_core(im_data, 2, 98)


def Strech_linear5(im_data):
    """5%-95% percentile linear stretch of the valid pixels to 1..255; invalid -> 0."""
    return _strech_linear_core(im_data, 5, 95)
def Strech_SquareRoot(im_data):
    """Square-root stretch to uint8.

    The square root of the input is stretched linearly so that the window
    ``[lo, hi]`` maps to 1..255 (clipped to 0..255). The window is the full
    min/max of the finite roots, unless that range is more than 3x wider than
    the 2%-98% percentile range (20250904). In that case most pixels would be
    squeezed into a narrow band, so the percentile window is used instead.
    Non-finite roots (NaN, from negative or NaN input, and +inf) map to 0.

    :param im_data: numpy array (or array-like) of any shape
    :return: uint8 array with the same shape
    """
    with np.errstate(invalid="ignore"):
        root = np.sqrt(im_data)  # negative inputs become NaN
    valid = np.isfinite(root)
    values = root[valid]
    if values.size == 0:
        # Nothing to stretch against; the old code raised here.
        return np.zeros(root.shape, dtype=np.uint8)
    minvalue = values.min()
    maxvalue = values.max()
    low_2pct = np.percentile(values, 2)
    high_98pct = np.percentile(values, 98)
    print('sqrt root min - max ', minvalue, maxvalue)
    robust_span = high_98pct - low_2pct
    # Equivalent to (max-min)/(p98-p2) > 3, but it no longer divides by zero when
    # >=96% of the pixels share one value. In that case the full range is kept.
    if robust_span > 0 and (maxvalue - minvalue) > 3 * robust_span:
        minvalue = low_2pct
        maxvalue = high_98pct
        print('sqrt root min(2) - max(98) ', minvalue, maxvalue)
    with np.errstate(invalid="ignore", over="ignore"):
        if maxvalue > minvalue:
            scaled = (root - minvalue) / (maxvalue - minvalue) * 254 + 1
        else:
            # Constant image: avoid 0/0 and use a step around the single level.
            scaled = np.where(root > maxvalue, 255.0, np.where(root < minvalue, 0.0, 1.0))
    scaled[~valid] = 0
    return np.clip(scaled, 0, 255).astype(np.uint8)
def DataStrech(im_data, strechmethod):
    """Apply the contrast stretch named by ``strechmethod``.

    Recognised names: "Linear", "Linear1", "Linear2", "Linear5" and
    "SquareRoot". Any other value (including None) falls back to a plain
    ``astype(np.uint8)`` cast with no rescaling.
    """
    if strechmethod not in ("Linear", "Linear1", "Linear2", "Linear5", "SquareRoot"):
        # Unknown or missing method: cast only, no contrast stretch.
        return im_data.astype(np.uint8)
    if strechmethod == "SquareRoot":
        return Strech_SquareRoot(im_data)
    if strechmethod == "Linear":
        return Strech_linear(im_data)
    if strechmethod == "Linear1":
        return Strech_linear1(im_data)
    if strechmethod == "Linear2":
        return Strech_linear2(im_data)
    return Strech_linear5(im_data)
# File mode
def stretchProcess(infilepath, outfilepath, strechmethod):
    """Stretch one raster and save it as a PNG plus an 8-bit GeoTIFF.

    Writes ``outfilepath`` as an uncompressed PNG. Next to it, writes
    ``<stem>.bin`` containing the same uint8 data via ``write_envi`` (which
    uses GDAL's GTiff driver).

    :param infilepath: input raster readable by GDAL (single band)
    :param outfilepath: output PNG path
    :param strechmethod: stretch name understood by ``DataStrech``
    :raises RuntimeError: if the input raster cannot be opened
    """
    im_proj, im_Geotrans, im_data = read_tif(infilepath)
    if im_data is None:
        # read_tif signals failure with Nones; fail here with a clear message
        # instead of an opaque TypeError deep inside the stretch.
        raise RuntimeError("无法打开输入影像: {}".format(infilepath))
    envifilepath = os.path.join(
        os.path.dirname(outfilepath),
        get_filename_without_ext(outfilepath) + ".bin",
    )
    im_data = DataStrech(im_data, strechmethod).astype(np.uint8)
    write_envi(im_data, im_Geotrans, im_proj, envifilepath)
    Image.fromarray(im_data).save(outfilepath, compress_level=0)
    print("图像拉伸处理结束")
def getsliceGeotrans(GeoTransform, Xpixel, Ypixel):
    """Geotransform of a sub-window whose top-left pixel is (Xpixel, Ypixel).

    The window's origin is the parent's affine transform evaluated at that
    pixel. Pixel size and rotation terms are copied unchanged.

    :return: a new 6-element list in GDAL geotransform order
    """
    origin_x, pixel_w, rot_x = GeoTransform[0], GeoTransform[1], GeoTransform[2]
    origin_y, rot_y, pixel_h = GeoTransform[3], GeoTransform[4], GeoTransform[5]
    new_x = origin_x + pixel_w * Xpixel + rot_x * Ypixel
    new_y = origin_y + rot_y * Xpixel + pixel_h * Ypixel
    return [new_x, pixel_w, rot_x, new_y, rot_y, pixel_h]
def is_all_same(lst, max_diff_count=400):
    """Return True when a tile is (nearly) uniform.

    Counts the elements that differ from the tile's first element and returns
    True when fewer than ``max_diff_count`` do. ``sliceShipDataset`` uses this
    to skip blank tiles. An empty input counts as uniform.

    :param lst: array-like of any shape
    :param max_diff_count: number of differing elements at which the tile
        stops counting as uniform (default 400, the previous hard-coded value)
    """
    arr = np.asarray(lst)
    if arr.size == 0:
        return True
    # Compare with the first *element*. The old ``arr != arr[0]`` compared a 2-D
    # tile against its first row (column-wise broadcast), so tiles made of
    # identical rows were wrongly reported as uniform.
    return int(np.count_nonzero(arr != arr.flat[0])) < max_diff_count
def getNextSliceNumber(n, sliceSize, overlap=0.25):
    """Plan tile start offsets along one image axis.

    Starts are ``0, step, 2*step, ...`` (all ``< n``) with
    ``step = int(sliceSize * (1 - overlap)) + 1`` (at least 1). Also prints the
    overlap percentage between consecutive tiles.

    :param n: axis length in pixels
    :param sliceSize: tile edge length in pixels
    :param overlap: fractional overlap between neighbouring tiles
    :return: ``(newN, starts)``. ``newN`` is the axis length after padding so
        that the last tile fits entirely. For ``n <= 0`` this is ``(n, [])``.
    """
    if n <= 0:
        return n, []
    step = max(int(sliceSize * (1 - overlap)) + 1, 1)
    starts = list(range(0, n, step))
    # Use sliceSize here (was a hard-coded 1024, wrong for any other tile size).
    last_end = starts[-1] + sliceSize
    newN = n if last_end < n else last_end
    overlaps = [(a + sliceSize - b) / sliceSize * 100.0 for a, b in zip(starts, starts[1:])]
    print("重叠率:", overlaps)
    return newN, starts
def sliceShipDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder):
    """Cut the image into overlapping shipsliceSize tiles for ship detection.

    Both arrays are zero-padded at the bottom/right so the tile grid from
    ``getNextSliceNumber`` fits. Every grid position consumes one slice ID
    (counted from 1), but tiles that ``is_all_same`` reports as uniform are not
    written. Each kept tile is written three times under ``<outfolder>/舰船``:
    the source values via ``write_tiff`` (tifffolder), the stretched uint8 data
    via ``write_envi`` (unit8binfolder), and an uncompressed PNG (pngfolder).

    :return: the number of grid positions (the last slice ID)
    """
    ship_root = os.path.join(outfolder, "舰船")
    binfolder = os.path.join(ship_root, "unit8binfolder")
    pngfolder = os.path.join(ship_root, "pngfolder")
    tifffolder = os.path.join(ship_root, "tifffolder")

    rows, cols = im_data.shape
    padded_rows, row_starts = getNextSliceNumber(rows, shipsliceSize, BlockOverLayer)
    padded_cols, col_starts = getNextSliceNumber(cols, shipsliceSize, BlockOverLayer)
    pad_spec = ((0, padded_rows - rows), (0, padded_cols - cols))
    stretched = np.pad(im_data, pad_spec, mode='constant', constant_values=0)
    source = np.pad(src_im_data, pad_spec, mode='constant', constant_values=0)

    slice_ID = 0
    for top in row_starts:
        for left in col_starts:
            slice_ID += 1
            window = (slice(top, top + shipsliceSize), slice(left, left + shipsliceSize))
            tile = stretched[window]
            if is_all_same(tile):
                continue  # blank tile: keep the ID, write nothing
            tile_geotrans = getsliceGeotrans(im_Geotrans, left, top)
            stem = rootname + "_" + str(slice_ID).zfill(4) + "_image"
            write_tiff(source[window], tile_geotrans, im_proj, os.path.join(tifffolder, stem + ".tiff"))
            write_envi(tile, tile_geotrans, im_proj, os.path.join(binfolder, stem + ".tiff"))
            Image.fromarray(tile).save(os.path.join(pngfolder, stem + ".png"), compress_level=0)
    print("图像切片结束")
    return slice_ID
def ishasPort(im_Geotrans, im_data, MLCPoints, JLCPoints, MJLCPoints):
    """Return True if any port point lies inside the raster's footprint.

    The footprint is the axis-aligned bounding box of the four raster corners,
    computed with the full affine geotransform. This includes the rotation
    terms GT[2] and GT[4], which the previous version ignored. Bounds are
    inclusive.

    NOTE(review): this treats points as georeferenced coordinates in the
    geotransform's units, whereas slicePortDataset uses the same port-file
    points as pixel offsets. Confirm which convention the port file uses.

    :param im_Geotrans: GDAL geotransform (6 coefficients)
    :param im_data: array whose first two dimensions are (rows, cols)
    :param MLCPoints, JLCPoints, MJLCPoints: lists of ``[x, y]`` points
    """
    rows, cols = im_data.shape[0], im_data.shape[1]
    corner_x = []
    corner_y = []
    for pixel, line in ((0, 0), (cols, 0), (0, rows), (cols, rows)):
        corner_x.append(im_Geotrans[0] + im_Geotrans[1] * pixel + im_Geotrans[2] * line)
        corner_y.append(im_Geotrans[3] + im_Geotrans[4] * pixel + im_Geotrans[5] * line)
    xmin, xmax = min(corner_x), max(corner_x)
    ymin, ymax = min(corner_y), max(corner_y)
    return any(
        xmin <= p[0] <= xmax and ymin <= p[1] <= ymax
        for p in MLCPoints + JLCPoints + MJLCPoints
    )
def _read_port_points(portfilestr):
    """Parse a port file into ``{"JLC": [...], "MJLC": [...], "MLC": [...]}``.

    Each line longer than 3 characters is split on a double tab. Field 0 is
    the class name and field 1 is ``"x,y"``. Points of other classes are
    ignored. ``utf-8-sig`` drops a leading BOM that would otherwise stick to
    the first class name.
    """
    points = {"JLC": [], "MJLC": [], "MLC": []}
    with open(portfilestr, "r", encoding="utf-8-sig") as portfile:
        for line in portfile:
            if len(line) <= 3:
                continue
            linemetas = line.split("\t\t")
            clsname = linemetas[0]
            coords = linemetas[1].split(",")
            pointx, pointy = float(coords[0]), float(coords[1])
            if clsname in points:
                points[clsname].append([pointx, pointy])
    return points


def slicePortDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder, slice_ID, portfilestr):
    """Cut a portsliceSize x portsliceSize tile around every port point.

    Points are used as pixel offsets (x = column, y = row) into ``im_data``.
    The window is centred on the point and clamped at the top/left image edge,
    so tiles touching the bottom/right edge may be smaller than portsliceSize.
    Points whose window is empty (outside the image) are skipped. Tiles are
    written in class order JLC, MJLC, MLC with consecutive IDs starting at
    ``slice_ID + 1``. Each tile goes to ``<outfolder>/港口`` three times: the
    source values (tifffolder, via write_tiff), the uint8 data (unit8binfolder,
    via write_envi), and a PNG (pngfolder).

    :return: the last slice ID used; ``slice_ID`` unchanged when the file has no port points
    """
    points = _read_port_points(portfilestr)
    if not any(points.values()):
        # Previously returned None here, which the caller stored as slice_ID.
        return slice_ID
    port_root = os.path.join(outfolder, "港口")
    binfolder = os.path.join(port_root, "unit8binfolder")
    pngfolder = os.path.join(port_root, "pngfolder")
    tifffolder = os.path.join(port_root, "tifffolder")
    for clsname in ("JLC", "MJLC", "MLC"):
        for px, py in points[clsname]:
            # int(): slice bounds must be integers (the float bounds used before raised TypeError).
            wi = max(int(px - portsliceSize / 2), 0)
            hi = max(int(py - portsliceSize / 2), 0)
            im_data_temp = im_data[hi:hi + portsliceSize, wi:wi + portsliceSize]
            if im_data_temp.size == 0:
                print("港口点位超出影像范围,跳过:", clsname, px, py)
                continue
            src_im_data_temp = src_im_data[hi:hi + portsliceSize, wi:wi + portsliceSize]
            slice_ID += 1
            stem = rootname + "_" + str(slice_ID).zfill(4) + "_image"
            geotrans_temp = getsliceGeotrans(im_Geotrans, wi, hi)
            write_tiff(src_im_data_temp, geotrans_temp, im_proj, os.path.join(tifffolder, stem + ".tiff"))
            write_envi(im_data_temp, geotrans_temp, im_proj, os.path.join(binfolder, stem + ".tiff"))
            Image.fromarray(im_data_temp).save(os.path.join(pngfolder, stem + ".png"), compress_level=4)
    print("图像切片结束")
    return slice_ID
def sliceLabelPortDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder, slice_ID, portfilestr):
    """Export the whole stretched image plus a CSV of all labelled points.

    The port file is parsed once: lines longer than 3 characters are split on a
    double tab into a class name and ``"x,y"``. If none of the lines is a JLC,
    MJLC or MLC point, nothing is written. Otherwise the function writes:

    * ``<outfolder>/港口/unit8tiff/<rootname>_uint8.tiff`` (via write_envi)
    * ``<outfolder>/港口/MLCLabels/<rootname>_uint8.csv`` with one ``x,y,class``
      row per parsed line (all classes)

    ``src_im_data`` is accepted for signature compatibility but not used.

    :return: ``slice_ID`` unchanged (no numbered slices are produced)
    """
    rows = []
    # utf-8-sig drops a leading BOM that would otherwise stick to the first class name.
    with open(portfilestr, "r", encoding="utf-8-sig") as portfile:
        for line in portfile:
            if len(line) > 3:
                linemetas = line.split("\t\t")
                coords = linemetas[1].split(",")
                rows.append((float(coords[0]), float(coords[1]), linemetas[0]))
    if not any(clsname in ("JLC", "MJLC", "MLC") for _, _, clsname in rows):
        return slice_ID
    portuint8Tifffolder = os.path.join(outfolder, "港口", "unit8tiff")
    portlabelfolder = os.path.join(outfolder, "港口", "MLCLabels")
    # NOTE(review): stretchSliceProcess creates "港口/PortLabels", not "MLCLabels",
    # so the CSV open used to fail; create the folders here. Confirm which name is intended.
    os.makedirs(portuint8Tifffolder, exist_ok=True)
    os.makedirs(portlabelfolder, exist_ok=True)
    unit8tiffPath = os.path.join(portuint8Tifffolder, "{}_uint8.tiff".format(rootname))
    uint8labelPath = os.path.join(portlabelfolder, "{}_uint8.csv".format(rootname))
    write_envi(im_data, im_Geotrans, im_proj, unit8tiffPath)
    with open(uint8labelPath, "w", encoding="utf-8") as labelfile:
        labelfile.writelines("{},{},{}\n".format(x, y, clsname) for x, y, clsname in rows)
    return slice_ID
def stretchSliceProcess(infilepath, outfolder, portfilestr, strechmethod):
    """Slice mode: stretch a raster, save an overview PNG, then cut ship and port tiles.

    Creates the output folder tree under ``outfolder`` (舰船/..., 港口/...,
    allpngfolder). Saves ``allpngfolder/<stem>_all.png`` and runs
    ``sliceShipDataset``. It then runs ``slicePortDataset`` when
    ``portfilestr`` names an existing file; otherwise port slicing is skipped
    with a message instead of aborting after the ship tiles were written.

    :raises RuntimeError: if the input raster cannot be opened
    """
    ship_root = os.path.join(outfolder, "舰船")
    port_root = os.path.join(outfolder, "港口")
    allpngfolder = os.path.join(outfolder, "allpngfolder")
    for folder in (
        os.path.join(ship_root, "unit8binfolder"),
        os.path.join(ship_root, "pngfolder"),
        os.path.join(ship_root, "tifffolder"),
        allpngfolder,
        os.path.join(port_root, "unit8binfolder"),
        os.path.join(port_root, "pngfolder"),
        os.path.join(port_root, "tifffolder"),
        os.path.join(port_root, "unit8tiff"),
        os.path.join(port_root, "PortLabels"),
    ):
        existOrCreate(folder)

    im_proj, im_Geotrans, im_data = read_tif(infilepath)
    if im_data is None:
        raise RuntimeError("无法打开输入影像: {}".format(infilepath))
    src_im_data = im_data * 1.0  # float copy of the un-stretched values, kept for the float tiles
    im_data = DataStrech(im_data, strechmethod).astype(np.uint8)
    rootname = Path(infilepath).stem
    allImagePath = os.path.join(allpngfolder, rootname + "_all.png")
    Image.fromarray(im_data).save(allImagePath, compress_level=0)

    slice_ID = sliceShipDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder)
    # NOTE(review): together with the pre-increment inside slicePortDataset this
    # leaves one unused ID between ship and port tiles; kept so file names don't shift.
    slice_ID = slice_ID + 1
    if portfilestr and os.path.isfile(portfilestr):
        slice_ID = slicePortDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder, slice_ID, portfilestr)
    else:
        print("未找到港口文件,跳过港口切片:", portfilestr)
    print("图像切片与拉伸完成")
def getParams():
    """Parse the command line.

    :return: ``argparse.Namespace`` with ``infile``, ``portfile``, ``outfile``,
        ``mode`` ("filemode" | "slicemode", default "slicemode") and ``method``
        ("Linear" | "Linear1" | "Linear2" | "Linear5" | "SquareRoot",
        default "SquareRoot").
    """
    parser = argparse.ArgumentParser()
    # Help texts previously said "shapefile"/"geojson" (copy-paste leftovers).
    parser.add_argument('-i', '--infile', type=str, default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.tiff", help='输入TIFF影像文件')
    parser.add_argument('-p', '--portfile', type=str, default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.txt", help='港口点位文件(每行: 类别<两个Tab>x,y; 仅切片模式使用)')
    parser.add_argument('-o', '--outfile', type=str, default=r"F:\天仪SAR卫星数据集\舰船数据\切片结果", help='输出PNG文件(文件模式)或输出目录(切片模式)')

    mode_group = parser.add_mutually_exclusive_group()
    for flag, const, help_text in (
        ('--filemode', 'filemode', '文件模式'),
        ('--slicemode', 'slicemode', '切片模式'),
    ):
        mode_group.add_argument(flag, action='store_const', const=const, dest='mode', help=help_text)
    parser.set_defaults(mode='slicemode')

    # argparse %-formats help strings, so a literal percent must be written '%%'
    # (a bare '1%线...' made `-h` crash with ValueError).
    method_group = parser.add_mutually_exclusive_group()
    for flag, const, help_text in (
        ('--Linear', 'Linear', '线性拉伸'),
        ('--Linear1prec', 'Linear1', '1%%线性拉伸'),
        ('--Linear2prec', 'Linear2', '2%%线性拉伸'),
        ('--Linear5prec', 'Linear5', '5%%线性拉伸'),
        ('--SquareRoot', 'SquareRoot', '平方根拉伸'),
    ):
        method_group.add_argument(flag, action='store_const', const=const, dest='method', help=help_text)
    parser.set_defaults(method='SquareRoot')
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point. Exit codes: 2 = unknown mode, 3 = error during processing.
    import sys
    import traceback

    try:
        args = getParams()
        intiffPath = args.infile
        modestr = args.mode
        methodstr = args.method
        if modestr == "filemode":
            outbinPath = args.outfile
            print('infile=', intiffPath)
            print('outfile=', outbinPath)
            print('method=', methodstr)
            stretchProcess(intiffPath, outbinPath, methodstr)
        elif modestr == "slicemode":
            outfolder = args.outfile
            portfilestr = args.portfile
            print('infile=', intiffPath)
            print('outfolder=', outfolder)
            print('method=', methodstr)
            print('portfile=', portfilestr)
            stretchSliceProcess(intiffPath, outfolder, portfilestr, methodstr)
        else:
            print("模式错误")
            # sys.exit instead of the site-provided `exit`, which is absent under `python -S`.
            sys.exit(2)
    except Exception as e:
        print(e)
        traceback.print_exc()  # keep the stack trace; print(e) alone hides where it failed
        sys.exit(3)