SpacetySliceTools/tools/SpacetyTIFFDataStretch2PNG.py

397 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from osgeo import ogr, gdal
import os
import argparse
import numpy as np
from PIL import Image
import math
from pathlib import Path
# Tile edge length in pixels used by the slicing helpers.
sliceSize: int = 1024
# Fractional overlap between neighbouring tiles (passed to getNextSliceNumber).
BlockOverLayer: float = 0.25
def get_filename_without_ext(path):
    """Return the last component of *path* with its final extension removed.

    Only the text after the last dot is dropped ("c.tar.gz" -> "c.tar").
    Names that contain no dot, or that start with a dot (e.g. ".bashrc",
    ".hidden.txt"), are returned unchanged.
    """
    name = os.path.basename(path)
    head, dot, _ext = name.rpartition('.')
    if not dot or name.startswith('.'):
        return name
    return head
def read_tif(path):
    """Open a raster with GDAL and read all of its pixels into memory.

    The raster's row, column and band counts are printed.

    :param path: path of the raster file
    :return: ``(projection_wkt, geotransform, pixels)`` where *pixels* is the
        array returned by ``ReadAsArray`` for the full extent; all three are
        ``None`` (after printing a message) when GDAL cannot open the file.
    """
    ds = gdal.Open(path)
    if ds is None:
        print("无法打开文件")
        return None, None, None

    width, height, band_count = ds.RasterXSize, ds.RasterYSize, ds.RasterCount
    projection = ds.GetProjection()
    geotransform = ds.GetGeoTransform()
    pixels = ds.ReadAsArray(0, 0, width, height)

    print("行数:", height)
    print("列数:", width)
    print("波段:", band_count)

    ds = None  # drop the last reference so GDAL closes the file
    return projection, geotransform, pixels
def _as_band_stack(im_data):
    """Return *im_data* as a 3-D ``(bands, rows, cols)`` numpy array.

    A 2-D ``(rows, cols)`` array is treated as a single band.

    :raises ValueError: if *im_data* is neither 2-D nor 3-D
    """
    data = np.asarray(im_data)
    if data.ndim == 2:
        return data[np.newaxis, :, :]
    if data.ndim == 3:
        return data
    raise ValueError(f"expected a 2-D or 3-D array, got shape {data.shape}")


def _write_geotiff(bands, im_geotrans, im_proj, output_path, gdal_type):
    """Write a ``(bands, rows, cols)`` array to a georeferenced GeoTIFF.

    :raises RuntimeError: if GDAL cannot create *output_path*
    """
    band_count, height, width = bands.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_path, width, height, band_count, gdal_type)
    if dataset is None:
        # Previously ignored silently, which produced no output and no error.
        raise RuntimeError(f"GDAL could not create {output_path}")
    try:
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)
        for index, band in enumerate(bands, start=1):  # GDAL bands are 1-based
            dataset.GetRasterBand(index).WriteArray(band)
        dataset.FlushCache()
    finally:
        # Dropping the last reference closes the dataset, even on error.
        dataset = None


def write_envi(im_data, im_geotrans, im_proj, output_path):
    """Write an 8-bit (``GDT_Byte``) georeferenced raster to *output_path*.

    Despite the name, the file is written with GDAL's ``GTiff`` driver, not
    the ENVI driver: exactly one file is created at *output_path* (whatever
    its extension) and no ``.hdr`` side-car file is produced.

    :param im_data: numpy array, 2-D ``(rows, cols)`` or 3-D
        ``(bands, rows, cols)``
    :param im_geotrans: 6-element GDAL affine geotransform
    :param im_proj: projection as a WKT string
    :param output_path: full output path, including the extension
    :raises ValueError: if *im_data* is neither 2-D nor 3-D
    :raises RuntimeError: if GDAL cannot create the file
    """
    bands = _as_band_stack(im_data)
    _write_geotiff(bands, im_geotrans, im_proj, output_path, gdal.GDT_Byte)


def write_tiff(im_data, im_geotrans, im_proj, output_path):
    """Write a 32-bit float (``GDT_Float32``) georeferenced GeoTIFF.

    Exactly one file is created at *output_path*.

    :param im_data: numpy array, 2-D ``(rows, cols)`` or 3-D
        ``(bands, rows, cols)``
    :param im_geotrans: 6-element GDAL affine geotransform
    :param im_proj: projection as a WKT string
    :param output_path: full output path, including the extension
    :raises ValueError: if *im_data* is neither 2-D nor 3-D
    :raises RuntimeError: if GDAL cannot create the file
    """
    bands = _as_band_stack(im_data)
    _write_geotiff(bands, im_geotrans, im_proj, output_path, gdal.GDT_Float32)
def _stretch_valid_pixels(im_data, low_percent=None, high_percent=None):
    """Linearly stretch the valid pixels of *im_data* onto uint8 [1, 255].

    A pixel is *valid* when it is finite and strictly positive, i.e. exactly
    the pixels whose dB value ``10*log10(x)`` is finite.  The stretch bounds
    ``lo``/``hi`` are the min/max of the valid pixels, or their
    ``low_percent``/``high_percent`` percentiles when given.  Valid pixels map
    to ``clip((x - lo) / (hi - lo) * 254 + 1, 0, 255)``.  Invalid pixels
    (0, negative, NaN, +/-inf) map to 0.  If ``hi == lo``, valid pixels
    above/at/below ``lo`` map to 255/1/0 instead of dividing by zero.

    The input is never modified.  Valid values are processed in float64, so
    integer rasters cannot wrap around.

    :param im_data: numpy array (any shape) of linear-scale pixel values
    :param low_percent: percentile for the lower bound, or None for the minimum
    :param high_percent: percentile for the upper bound, or None for the maximum
    :return: uint8 array with the same shape as *im_data*
    """
    data = np.asarray(im_data)
    with np.errstate(invalid="ignore"):  # NaN comparisons are expected here
        valid = np.isfinite(data) & (data > 0)
    out = np.zeros(data.shape, dtype=np.uint8)
    values = data[valid].astype(np.float64)
    if values.size == 0:
        return out  # nothing to derive a stretch range from
    lo = values.min() if low_percent is None else np.percentile(values, low_percent)
    hi = values.max() if high_percent is None else np.percentile(values, high_percent)
    if hi > lo:
        scaled = (values - lo) / (hi - lo) * 254 + 1
    else:
        # Degenerate range: above -> 255, equal -> 1, below -> -253 (clipped to 0).
        scaled = np.sign(values - lo) * 254 + 1
    out[valid] = np.clip(scaled, 0, 255).astype(np.uint8)
    return out


def Strech_linear(im_data):
    """Min/max linear stretch of the valid (finite, > 0) pixels to uint8.

    Invalid pixels (0, negative, NaN, +/-inf) become 0; the input array is
    not modified.
    """
    return _stretch_valid_pixels(im_data)


def Strech_linear1(im_data):
    """1 % linear stretch: bounds are the 1st/99th percentiles of the valid pixels.

    Invalid pixels (0, negative, NaN, +/-inf) become 0; the input array is
    not modified.
    """
    return _stretch_valid_pixels(im_data, 1, 99)


def Strech_linear2(im_data):
    """2 % linear stretch: bounds are the 2nd/98th percentiles of the valid pixels.

    Invalid pixels (0, negative, NaN, +/-inf) become 0; the input array is
    not modified.
    """
    return _stretch_valid_pixels(im_data, 2, 98)


def Strech_linear5(im_data):
    """5 % linear stretch: bounds are the 5th/95th percentiles of the valid pixels.

    Invalid pixels (0, negative, NaN, +/-inf) become 0; the input array is
    not modified.
    """
    return _stretch_valid_pixels(im_data, 5, 95)
def Strech_SquareRoot(im_data):
    """Square-root stretch of *im_data* to uint8.

    1. Take the square root of every pixel (NaN or negative inputs give NaN).
    2. From the finite roots compute min/max and the 0.1th/99.9th
       percentiles.  If the min-max range is more than 3x wider than the
       percentile range, a few outliers would squeeze most pixels into a
       narrow band of grey levels, so the percentile bounds are used instead.
    3. Map roots with ``(r - lo) / (hi - lo) * 254 + 1``, clip to [0, 255]
       and cast to uint8.

    Pixels whose root is NaN become 0; +inf inputs become 255.  If the
    chosen range is empty (``hi == lo``), roots above/at/below ``lo`` map to
    255/1/0 instead of dividing by zero.  An input with no finite root
    yields an all-zero image.  The chosen bounds are printed.

    :param im_data: numpy array of pixel values (any shape)
    :return: uint8 array with the same shape as *im_data*
    """
    with np.errstate(invalid="ignore"):  # sqrt of negatives -> NaN, handled below
        roots = np.sqrt(im_data)
    finite_roots = roots[np.isfinite(roots)]
    if finite_roots.size == 0:
        return np.zeros(roots.shape, dtype=np.uint8)

    minvalue = np.nanmin(finite_roots)
    maxvalue = np.nanmax(finite_roots)
    low_pct = np.percentile(finite_roots, 0.1)    # 0.1 % / 99.9 % bounds
    high_pct = np.percentile(finite_roots, 99.9)
    print('sqrt root min - max ', minvalue, maxvalue)

    pct_span = high_pct - low_pct
    if pct_span > 0:
        use_pct = (maxvalue - minvalue) / pct_span > 3
    else:
        # Percentile range collapsed: any spread in min/max is pure outliers.
        use_pct = maxvalue > minvalue
    if use_pct:
        minvalue, maxvalue = low_pct, high_pct
        print('sqrt root min(0.1) - max(99.9) ', minvalue, maxvalue)

    span = maxvalue - minvalue
    if span > 0:
        stretched = (roots - minvalue) / span * 254 + 1
    else:
        # Degenerate range: above -> 255, equal -> 1, below -> -253 (clipped to 0).
        stretched = np.sign(roots - minvalue) * 254 + 1
    # NaN roots would make the uint8 cast undefined; map them to 0 explicitly.
    stretched = np.nan_to_num(np.clip(stretched, 0, 255), nan=0.0)
    return stretched.astype(np.uint8)
def DataStrech(im_data, strechmethod):
    """Apply the stretch selected by *strechmethod* to *im_data*.

    Recognised names: "Linear", "Linear1", "Linear2", "Linear5" and
    "SquareRoot".  Any other value falls back to a plain
    ``astype(np.uint8)`` cast with no stretching.
    """
    dispatch = (
        ("Linear", Strech_linear),
        ("Linear1", Strech_linear1),
        ("Linear2", Strech_linear2),
        ("Linear5", Strech_linear5),
        ("SquareRoot", Strech_SquareRoot),
    )
    for name, stretch in dispatch:
        if strechmethod == name:
            return stretch(im_data)
    return im_data.astype(np.uint8)
# File mode: stretch a whole raster into a single 8-bit image.
def stretchProcess(infilepath, outfilepath, strechmethod):
    """Stretch a whole raster to 8 bits and save the result twice.

    Outputs:
    * *outfilepath*: the stretched pixels saved with PIL
      (``compress_level=0``).
    * ``<name>.bin`` in the same directory as *outfilepath*, where
      ``<name>`` is *outfilepath*'s file name without its extension: the
      same pixels plus georeferencing, written by ``write_envi``.

    :param infilepath: raster file readable by GDAL
    :param outfilepath: output image path
    :param strechmethod: stretch name understood by ``DataStrech``
    :raises OSError: if the input raster cannot be read
    """
    im_proj, im_Geotrans, im_data = read_tif(infilepath)
    if im_data is None:
        # read_tif reports failure as (None, None, None); stop here with a
        # clear message instead of a confusing error inside the stretch code.
        raise OSError(f"无法打开文件: {infilepath}")

    bin_name = get_filename_without_ext(outfilepath) + ".bin"
    envifilepath = os.path.join(os.path.dirname(outfilepath), bin_name)

    im_data = DataStrech(im_data, strechmethod).astype(np.uint8, copy=False)
    write_envi(im_data, im_Geotrans, im_proj, envifilepath)
    Image.fromarray(im_data).save(outfilepath, compress_level=0)
    print("图像拉伸处理结束")
# Slice mode: helpers that cut a raster into overlapping square tiles.
def getSlicePoints(h, size=None):
    """Return tile start offsets along an axis of length *h*.

    The last tile ends exactly at *h*.  With ``extra = h - size`` the number
    of steps is ``max(1, floor(extra * 1.2 / size))`` and
    ``step = ceil(extra / steps)``.  The starts are ``0, step, 2*step, ...``
    below *extra*, followed by *extra* itself.  The overlap percentage
    between consecutive tiles is printed.

    :param h: axis length in pixels
    :param size: tile edge length in pixels; defaults to the module-level
        ``sliceSize``
    :return: list of int start offsets; ``[0]`` when ``h <= size`` (the
        caller must pad the axis to *size*)
    """
    if size is None:
        size = sliceSize
    extra = h - size
    if extra <= 0:
        # Image no larger than one tile.  Previously this divided by zero
        # or produced negative offsets.
        ti = [0]
    else:
        steps = max(1, int(math.floor(extra * 1.2 / size)))
        step = int(math.ceil(extra / steps))
        ti = list(range(0, extra, step))
        ti.append(extra)
    # Overlap between consecutive tiles, in percent of the tile size.
    movelayer = [(a + size - b) / size * 100.0 for a, b in zip(ti, ti[1:])]
    print("重叠率:", movelayer)
    return ti
def getsliceGeotrans(GeoTransform, Xpixel, Ypixel):
    """Geotransform of a sub-window whose top-left pixel is (Xpixel, Ypixel).

    The window origin is pushed through the parent's affine transform.
    Pixel sizes and rotation terms are copied unchanged.

    :param GeoTransform: parent geotransform in GDAL order
        (x0, pixel_w, row_rot, y0, col_rot, pixel_h)
    :param Xpixel: column offset of the window in the parent raster
    :param Ypixel: row offset of the window in the parent raster
    :return: 6-element list in the same order as *GeoTransform*
    """
    x0, pixel_w, row_rot, y0, col_rot, pixel_h = (GeoTransform[k] for k in range(6))
    new_x0 = x0 + pixel_w * Xpixel + row_rot * Ypixel
    new_y0 = y0 + col_rot * Xpixel + pixel_h * Ypixel
    return [new_x0, pixel_w, row_rot, new_y0, col_rot, pixel_h]
def is_all_same(lst, max_diff=400):
    """Return True when *lst* is (almost) uniform.

    Counts the elements that differ from the first element (first in flat,
    row-major order) and returns True when that count is below *max_diff*.
    sliceDataset uses this to skip tiles with (almost) no content, such as
    all-zero padding.  An empty input is treated as uniform.

    :param lst: array-like of any shape
    :param max_diff: number of differing elements at which the input stops
        counting as uniform (default 400, the previously hard-coded value)
    """
    arr = np.asarray(lst)
    if arr.size == 0:
        return True
    # Compare against the first *element*.  arr[0] would be the whole first
    # row of a 2-D tile, which makes column-striped tiles look uniform.
    return int(np.count_nonzero(arr != arr.flat[0])) < max_diff
def getNextSliceNumber(n, sliceSize, overlap=0.25):
    """Plan overlapping tiles of edge *sliceSize* along an axis of length *n*.

    Tiles start at ``0, step, 2*step, ...`` with
    ``step = int(sliceSize * (1 - overlap)) + 1``.  A new tile is added only
    while the previous one still ends before *n*, so no trailing tile is
    produced whose real pixels are all covered by its predecessor.  The
    overlap percentage between consecutive tiles is printed.

    :param n: axis length in pixels (>= 1)
    :param sliceSize: tile edge length in pixels (>= 1)
    :param overlap: fraction of a tile shared with the next one
    :return: ``(newN, starts)``: *newN* = max(n, last start + sliceSize) is
        the length the axis must be zero-padded to so every tile is
        full-size, and *starts* is the list of tile start offsets
    :raises ValueError: if n or sliceSize is < 1, or overlap yields step < 1
    """
    if n < 1 or sliceSize < 1:
        raise ValueError(f"n and sliceSize must be >= 1, got n={n}, sliceSize={sliceSize}")
    step = int(sliceSize * (1 - overlap)) + 1
    if step < 1:
        raise ValueError(f"overlap={overlap} gives a non-positive tile step")

    ti = [0]
    while ti[-1] + sliceSize < n and ti[-1] + step < n:
        ti.append(ti[-1] + step)
    # Use the parameter (not a hard-coded 1024) so padding matches the tiles.
    newN = max(n, ti[-1] + sliceSize)

    movelayer = [(a + sliceSize - b) / sliceSize * 100.0 for a, b in zip(ti, ti[1:])]
    print("重叠率:", movelayer)
    return newN, ti
def sliceDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder):
    """Cut the stretched raster and its raw counterpart into overlapping tiles.

    Tile starts along each axis come from
    ``getNextSliceNumber(axis_len, sliceSize, BlockOverLayer)``.  Both
    arrays are zero-padded at the bottom/right to the returned lengths, so
    every tile is a full ``sliceSize x sliceSize`` window.  Tiles are
    numbered row-major from 1.  A number is consumed for every tile,
    including tiles skipped because ``is_all_same`` reports them as
    (nearly) uniform, so an ID always identifies the same grid position.

    For each kept tile named ``<rootname>_<ID:04d>`` three files are written,
    each with a geotransform shifted to the tile's top-left pixel:
    * ``tifffolder/<name>.tiff``: raw values, via ``write_tiff``
    * ``unit8binfolder/<name>.tiff``: stretched values, via ``write_envi``
    * ``pngfolder/<name>.png``: stretched values, PIL ``compress_level=0``

    :param rootname: base name for the tile files
    :param im_data: 2-D stretched (uint8) array
    :param src_im_data: 2-D raw array with the same shape as *im_data*
    :param im_Geotrans: geotransform of the full raster
    :param im_proj: projection WKT of the full raster
    :param outfolder: folder that receives the three sub-folders
    """
    binfolder = os.path.join(outfolder, "unit8binfolder")
    pngfolder = os.path.join(outfolder, "pngfolder")
    tifffolder = os.path.join(outfolder, "tifffolder")
    for folder in (binfolder, pngfolder, tifffolder):
        os.makedirs(folder, exist_ok=True)

    h, w = im_data.shape
    nextH, ht = getNextSliceNumber(h, sliceSize, BlockOverLayer)
    nextW, wt = getNextSliceNumber(w, sliceSize, BlockOverLayer)
    padding = ((0, nextH - h), (0, nextW - w))
    im_data = np.pad(im_data, padding, mode='constant', constant_values=0)
    src_im_data = np.pad(src_im_data, padding, mode='constant', constant_values=0)

    slice_ID = 0
    for hi in ht:
        for wi in wt:
            slice_ID += 1
            # Use sliceSize (not a hard-coded 1024) so tiles match the padding plan.
            im_data_temp = im_data[hi:hi + sliceSize, wi:wi + sliceSize]
            if is_all_same(im_data_temp):
                continue  # near-uniform tile (e.g. padding): skip, but keep its ID
            src_im_data_temp = src_im_data[hi:hi + sliceSize, wi:wi + sliceSize]
            geotrans_temp = getsliceGeotrans(im_Geotrans, wi, hi)
            tile_name = f"{rootname}_{slice_ID:04d}"
            write_tiff(src_im_data_temp, geotrans_temp, im_proj,
                       os.path.join(tifffolder, tile_name + ".tiff"))
            write_envi(im_data_temp, geotrans_temp, im_proj,
                       os.path.join(binfolder, tile_name + ".tiff"))
            Image.fromarray(im_data_temp).save(
                os.path.join(pngfolder, tile_name + ".png"), compress_level=0)
    print("图像切片结束")
def stretchSliceProcess(infilepath, outfolder, strechmethod):
    """Stretch a raster, save an overview PNG and cut it into tiles.

    Creates ``unit8binfolder``, ``pngfolder`` and ``tifffolder`` under
    *outfolder* if they are missing.  Saves the whole stretched image as
    ``<stem>_all.png`` in *outfolder* and passes both the stretched and the
    raw pixels to ``sliceDataset``.

    :param infilepath: raster file readable by GDAL
    :param outfolder: output folder
    :param strechmethod: stretch name understood by ``DataStrech``
    :raises OSError: if the input raster cannot be read
    """
    for sub in ("unit8binfolder", "pngfolder", "tifffolder"):
        os.makedirs(os.path.join(outfolder, sub), exist_ok=True)

    im_proj, im_Geotrans, im_data = read_tif(infilepath)
    if im_data is None:
        # read_tif reports failure as (None, None, None); stop here with a
        # clear message instead of a confusing error inside the stretch code.
        raise OSError(f"无法打开文件: {infilepath}")

    # Independent copy of the raw values (integer input becomes float),
    # taken before stretching, for the float tiles.
    src_im_data = im_data * 1.0
    im_data = DataStrech(im_data, strechmethod).astype(np.uint8, copy=False)

    rootname = Path(infilepath).stem
    allImagePath = os.path.join(outfolder, rootname + "_all.png")
    Image.fromarray(im_data).save(allImagePath, compress_level=0)
    sliceDataset(rootname, im_data, src_im_data, im_Geotrans, im_proj, outfolder)
    print("图像切片与拉伸完成")
def getParams():
    """Parse the command-line options of this tool.

    :return: ``argparse.Namespace`` with attributes
        ``infile`` (input raster path),
        ``outfile`` (output image path in file mode, output folder in slice
        mode),
        ``mode`` (``'filemode'`` or ``'slicemode'``, default ``'slicemode'``),
        ``method`` (``'Linear'``, ``'Linear1'``, ``'Linear2'``,
        ``'Linear5'`` or ``'SquareRoot'``, default ``'SquareRoot'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--infile', type=str,
        default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.tiff",
        help='输入TIFF影像文件',
    )
    parser.add_argument(
        '-o', '--outfile', type=str,
        default=r"F:\天仪SAR卫星数据集\舰船数据\切片结果",
        help='输出路径:文件模式下为PNG文件,切片模式下为输出文件夹',
    )

    # Processing mode: at most one flag, slice mode by default.
    mode_group = parser.add_mutually_exclusive_group()
    for flag, const, text in (
        ('--filemode', 'filemode', '文件模式'),
        ('--slicemode', 'slicemode', '切片模式'),
    ):
        mode_group.add_argument(flag, action='store_const', const=const,
                                dest='mode', help=text)
    parser.set_defaults(mode='slicemode')

    # Stretch method.  argparse %-formats help strings, so a literal percent
    # sign must be written as '%%'; a bare '%' makes --help crash.
    method_group = parser.add_mutually_exclusive_group()
    for flag, const, text in (
        ('--Linear', 'Linear', '线性拉伸'),
        ('--Linear1prec', 'Linear1', '1%%线性拉伸'),
        ('--Linear2prec', 'Linear2', '2%%线性拉伸'),
        ('--Linear5prec', 'Linear5', '5%%线性拉伸'),
        ('--SquareRoot', 'SquareRoot', '平方根拉伸'),
    ):
        method_group.add_argument(flag, action='store_const', const=const,
                                  dest='method', help=text)
    parser.set_defaults(method='SquareRoot')

    return parser.parse_args()
if __name__ == '__main__':
    # sys.exit is always available; the bare exit() builtin comes from the
    # site module and is missing under `python -S` or in frozen builds.
    import sys
    import traceback

    try:
        args = getParams()
        intiffPath = args.infile
        modestr = args.mode
        methodstr = args.method
        if modestr == "filemode":
            outbinPath = args.outfile
            print('infile=', intiffPath)
            print('outfile=', outbinPath)
            print('method=', methodstr)
            stretchProcess(intiffPath, outbinPath, methodstr)
        elif modestr == "slicemode":
            outfolder = args.outfile
            print('infile=', intiffPath)
            print('outfolder=', outfolder)
            print('method=', methodstr)
            stretchSliceProcess(intiffPath, outfolder, methodstr)
        else:
            print("模式错误")
            sys.exit(2)
    except Exception as e:
        # Top-level boundary: report the error plus a traceback on stderr,
        # then exit with status 3.  SystemExit is not caught here.
        print(e)
        traceback.print_exc()
        sys.exit(3)