# microproduct-l-sar/vegetationHeight-L-SAR/geocoding.py
#
# 233 lines
# 9.3 KiB
# Python

##########################
# 地理编码
##########################
from osgeo import osr
from osgeo import gdal
import numpy as np
import os
import math
from sklearn.neighbors import KDTree
from xml.etree import ElementTree as ET
from tool.algorithm.image.ImageHandle import ImageHandler
# os.environ['PROJ_LIB'] = os.getcwd()
def get_lambel_sentinel(isce_work_path):
    """Read the radar wavelength from the ISCE reference swath XML.

    Looks in ``<isce_work_path>/reference`` for files whose name starts with
    ``"IW"`` and contains ``".xml"`` (e.g. ``IW1.xml``).  The first one in
    sorted order is parsed.  The function returns the text of the first
    ``<value>`` element found under the element whose ``name`` attribute is
    ``"radarwavelength"``.

    Args:
        isce_work_path: ISCE working directory containing ``reference/``.

    Returns:
        The wavelength as a float, or ``None`` if the XML contains no such
        element.

    Raises:
        FileNotFoundError: if ``reference/`` is missing or holds no matching
            XML file.
    """
    reference_path = os.path.join(isce_work_path, "reference")
    # os.listdir order is arbitrary; sort so the chosen file is deterministic.
    candidates = sorted(
        name for name in os.listdir(reference_path)
        if name.startswith("IW") and ".xml" in name
    )
    if not candidates:
        raise FileNotFoundError(
            "no IW*.xml file found in {}".format(reference_path))
    iw_path = os.path.join(reference_path, candidates[0])

    # Parse the wavelength.
    root = ET.parse(iw_path).getroot()
    for node in root.iter():
        if node.attrib.get('name') == "radarwavelength":
            for child in node.iter():
                # Skip empty <value/> elements instead of crashing on None.
                if child.tag == "value" and child.text is not None:
                    return float(child.text.replace("\n", ""))
    return None
def get_long_lat_path(isce_work_path):
    """Return ``[lon_vrt_path, lat_vrt_path]`` for an ISCE working directory.

    Both files live in ``<isce_work_path>/merged/geom_reference`` and are
    named ``lon.rdr.vrt`` and ``lat.rdr.vrt`` respectively.
    """
    geom_dir = os.path.join(isce_work_path, "merged", "geom_reference")
    return [os.path.join(geom_dir, layer) for layer in ("lon.rdr.vrt", "lat.rdr.vrt")]
def get_hgt_path(isce_work_path):
    """Return ``<isce_work_path>/merged/geom_reference/hgt.rdr.vrt``."""
    path_parts = (isce_work_path, "merged", "geom_reference", "hgt.rdr.vrt")
    return os.path.join(*path_parts)
def get_Los_path(isce_work_path):
    """Return ``<isce_work_path>/merged/geom_reference/los.rdr.vrt``."""
    path_parts = (isce_work_path, "merged", "geom_reference", "los.rdr.vrt")
    return os.path.join(*path_parts)
def get_filt_fine_unw_path(isce_work_path):
    """List every ``*.unw.vrt`` file below ``<isce_work_path>/Igrams``.

    The directory tree is walked recursively; paths are returned in
    ``os.walk`` order.  A missing ``Igrams`` directory yields an empty list.
    """
    igram_root = os.path.join(isce_work_path, "Igrams")
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(igram_root)
        for name in names
        if name.endswith(".unw.vrt")
    ]
def info_VRT(vrt_path):
    """Open a raster read-only, print its size and per-band statistics, return it.

    Args:
        vrt_path: Path of a GDAL-readable raster (``.vrt`` files in this module).

    Returns:
        The open ``gdal.Dataset``.

    Raises:
        IOError: if GDAL cannot open ``vrt_path``.
    """
    vrt_data = gdal.Open(vrt_path, gdal.GA_ReadOnly)
    if vrt_data is None:
        raise IOError("GDAL could not open raster: {}".format(vrt_path))
    print("width:\t", vrt_data.RasterXSize)
    print("height:\t", vrt_data.RasterYSize)
    print("bands numbers:\t", vrt_data.RasterCount)
    for i in range(vrt_data.RasterCount):
        # Read only band i+1 so the printed statistics really belong to that
        # band (reading the whole dataset mixed all bands together and
        # re-read the full cube on every iteration).
        data = vrt_data.GetRasterBand(i + 1).ReadAsArray()
        print("band {} ,dtype:{},\tmin:{}\tmax:\t{}".format(
            i + 1, data.dtype, np.min(data), np.max(data)))
    return vrt_data
def vrt2tiff(vrt_path, tiff_path, band_idx=0):
    """Copy one band of a GDAL raster into a single-band Float32 GeoTIFF.

    The source is opened through :func:`info_VRT`, which also prints its
    size and per-band statistics.  No geotransform or projection is copied.

    Args:
        vrt_path: Source raster path.
        tiff_path: Output GeoTIFF path (overwritten).
        band_idx: Zero-based band index, negative values counting from the
            end as in numpy.  Only used when the source has more than one
            band; single-band sources always copy band 1.

    Raises:
        IndexError: if ``band_idx`` is out of range for a multi-band source.
        IOError: if the output GeoTIFF cannot be created.
    """
    vrt_data = info_VRT(vrt_path)
    band_count = vrt_data.RasterCount
    if band_count > 1:
        if not -band_count <= band_idx < band_count:
            raise IndexError("band_idx {} out of range for {} bands in {}".format(
                band_idx, band_count, vrt_path))
        source_band = band_idx % band_count + 1  # GDAL bands are 1-based
    else:
        source_band = 1
    # Read only the requested band instead of the whole multi-band cube.
    data = vrt_data.GetRasterBand(source_band).ReadAsArray()

    driver = gdal.GetDriverByName('GTiff')
    tiff_dataset = driver.Create(tiff_path, vrt_data.RasterXSize, vrt_data.RasterYSize,
                                 1, gdal.GDT_Float32)
    if tiff_dataset is None:
        raise IOError("GDAL could not create GeoTIFF: {}".format(tiff_path))
    tiff_dataset.GetRasterBand(1).WriteArray(data)
    # Dropping the last reference lets GDAL flush and close the file.
    del tiff_dataset
def read_tiff_dataset(tiff_data, band_idx=0):
    """Read one band of a GDAL raster into a 2-D numpy array.

    Args:
        tiff_data: Path of a GDAL-readable raster.
        band_idx: Zero-based band index, negative values counting from the
            end as in numpy.  Only used when the raster has more than one
            band; single-band rasters always return band 1.

    Returns:
        The band as a 2-D array (rows x columns).

    Raises:
        IOError: if GDAL cannot open ``tiff_data``.
        IndexError: if ``band_idx`` is out of range for a multi-band raster.
    """
    dataset = gdal.Open(tiff_data, gdal.GA_ReadOnly)
    if dataset is None:
        raise IOError("GDAL could not open raster: {}".format(tiff_data))
    band_count = dataset.RasterCount
    if band_count > 1:
        if not -band_count <= band_idx < band_count:
            raise IndexError("band_idx {} out of range for {} bands in {}".format(
                band_idx, band_count, tiff_data))
        band_number = band_idx % band_count + 1  # GDAL bands are 1-based
    else:
        band_number = 1
    # Read only the requested band instead of the whole multi-band cube.
    data = dataset.GetRasterBand(band_number).ReadAsArray()
    del dataset
    return data
def saveTiff(target_data_path, xsize, ysize, gt, srs, target_arr):
    """Write ``target_arr`` to a single-band Float32 GeoTIFF.

    The raster is ``int(xsize) + 1`` columns by ``int(ysize) + 1`` rows, so
    ``target_arr`` must have that shape.  ``gt`` is stored as the
    geotransform and ``srs`` (an ``osr.SpatialReference``) as the projection.
    """
    n_cols = int(xsize) + 1
    n_rows = int(ysize) + 1
    out_ds = gdal.GetDriverByName('GTiff').Create(
        target_data_path, n_cols, n_rows, 1, gdal.GDT_Float32)
    out_ds.GetRasterBand(1).WriteArray(target_arr)
    out_ds.SetGeoTransform(gt)
    out_ds.SetProjection(srs.ExportToWkt())
    # Releasing the last reference lets GDAL close the dataset.
    del out_ds
# Interpolation-based geocoding.
def geoCoding(tree, X_min, X_max, Y_min, Y_max, block_size, value_data, target_arr,
              k=4, max_dist=50):
    """Fill ``target_arr`` by inverse-distance interpolation of scattered samples.

    Every integer pixel ``(x, y)`` with ``X_min <= x <= X_max`` and
    ``Y_min <= y <= Y_max`` is queried against ``tree`` in tiles of at most
    ``block_size`` x ``block_size`` pixels to bound memory use:

    * if one of the ``k`` nearest samples lies exactly on the pixel, the
      nearest sample's value is copied;
    * otherwise, if the farthest of the ``k`` neighbours is within
      ``max_dist`` (in the tree's coordinate units), the pixel gets the
      inverse-distance weighted mean of the ``k`` neighbour values;
    * otherwise the pixel is left untouched.

    Args:
        tree: Spatial index whose ``query(points, k=k)`` returns ``(dist, ind)``
            arrays of shape ``(n, k)`` sorted by increasing distance
            (e.g. ``sklearn.neighbors.KDTree``).  Query points are ``[x, y]``.
        X_min, X_max, Y_min, Y_max: Inclusive integer pixel bounds to fill.
        block_size: Tile edge length in pixels.
        value_data: 1-D array of sample values, aligned with the tree's points.
        target_arr: 2-D output array indexed ``[y, x]``; modified in place.
        k: Number of neighbours used per pixel (default 4).
        max_dist: Largest allowed distance of the k-th neighbour for a pixel
            to be interpolated (default 50).

    Returns:
        ``target_arr`` (the same object, filled in place).
    """
    value_data = np.asarray(value_data)
    for i in range(X_min, X_max + 1, block_size):
        end_i = min(i + block_size, X_max + 1)  # exclusive; tiles do not overlap
        for j in range(Y_min, Y_max + 1, block_size):
            end_j = min(j + block_size, Y_max + 1)
            xs, ys = np.meshgrid(np.arange(i, end_i), np.arange(j, end_j))
            query = np.column_stack([xs.ravel(), ys.ravel()])
            if query.shape[0] == 0:
                continue
            dist, ind = tree.query(query, k=k)

            # Pixels that coincide with a sample: copy the nearest value.
            exact = (dist == 0).any(axis=1)
            target_arr[query[exact, 1], query[exact, 0]] = value_data[ind[exact, 0]]

            # Remaining pixels whose k-th neighbour is close enough: IDW mean.
            # All distances are > 0 here, so 1 / dist is safe.
            near = ~exact & (dist[:, -1] <= max_dist)
            if np.any(near):
                weights = 1.0 / dist[near]
                weights = weights / weights.sum(axis=1, keepdims=True)
                target_arr[query[near, 1], query[near, 0]] = np.sum(
                    weights * value_data[ind[near]], axis=1)
        # Progress: last processed column as a fraction of X_max
        # (guarded so a single-column grid with X_max == 0 does not divide by zero).
        last_i = end_i - 1
        print(last_i / X_max if X_max else 1.0, "...", end="")
    return target_arr
def get_Dem(isce_work_path, temp_work_path, pack_path, product_name, array,
            resolution=0.0002777777777777778, block_size=1000):
    """Geocode a radar-geometry array onto a regular EPSG:4326 grid and save it.

    The ISCE geometry layers (``lon``, ``lat``, ``hgt``, ``los`` under
    ``<isce_work_path>/merged/geom_reference``) are converted to GeoTIFFs in
    ``temp_work_path``.  Only pixels whose first ``los`` band is > 0 are kept.
    Their lon/lat positions are mapped to output pixel coordinates, and
    :func:`geoCoding` interpolates ``array`` onto the grid, with nodata -9999.

    Args:
        isce_work_path: ISCE working directory.
        temp_work_path: Directory for the intermediate ``lon.tiff``,
            ``lat.tiff``, ``hgt.tiff`` and ``los_uwm.tiff`` files.
        pack_path: Directory created if missing.  The output itself is
            written to ``product_name``.
        product_name: Path of the output GeoTIFF.
        array: Values in radar geometry.  It must have as many elements as the
            ``los`` raster and is flattened in the same row-major order.
        resolution: Output pixel size in degrees (default 1/3600, i.e. one
            arc-second).
        block_size: Tile size in pixels passed to :func:`geoCoding`.

    Returns:
        ``product_name``.

    Raises:
        ValueError: if ``array`` has the wrong number of elements, or if no
            pixel passes the ``los > 0`` mask.
    """
    # Convert the ISCE geometry VRTs to single-band GeoTIFFs in the temp dir.
    lon_path, lat_path = get_long_lat_path(isce_work_path)
    lon_tiff_path = os.path.join(temp_work_path, "lon.tiff")
    vrt2tiff(lon_path, lon_tiff_path)
    lat_tiff_path = os.path.join(temp_work_path, "lat.tiff")
    vrt2tiff(lat_path, lat_tiff_path)
    # hgt.tiff is not read below.  It is still produced so the temp
    # directory keeps the same contents as before.
    hgt_tiff_path = os.path.join(temp_work_path, "hgt.tiff")
    vrt2tiff(get_hgt_path(isce_work_path), hgt_tiff_path)
    los_tiff_path = os.path.join(temp_work_path, "los_uwm.tiff")
    vrt2tiff(get_Los_path(isce_work_path), los_tiff_path, 0)

    los_data = read_tiff_dataset(los_tiff_path, band_idx=0).reshape(-1)
    # Work in float64: float32 intermediates (lon * 3600 ~ 4e5) would add
    # ~0.03 px error to the pixel coordinates below.
    lon_data = read_tiff_dataset(lon_tiff_path, band_idx=0).reshape(-1).astype(np.float64)
    lat_data = read_tiff_dataset(lat_tiff_path, band_idx=0).reshape(-1).astype(np.float64)
    values = np.asarray(array).reshape(-1)
    if values.size != los_data.size:
        raise ValueError("array has {} elements but the los raster has {}".format(
            values.size, los_data.size))

    # Keep only pixels whose first los band is positive (used as validity mask).
    valid = los_data > 0
    if not valid.any():
        raise ValueError("no pixel has los > 0; nothing to geocode")
    lon_data = lon_data[valid]
    lat_data = lat_data[valid]
    values = values[valid]
    del los_data, valid

    # North-up output grid whose top-left corner is (min lon, max lat).
    lon_min, lon_max = float(np.min(lon_data)), float(np.max(lon_data))
    lat_min, lat_max = float(np.min(lat_data)), float(np.max(lat_data))
    gt = [lon_min, resolution, 0.0, lat_max, 0.0, -resolution]
    xsize = (lon_max - lon_min) / resolution + 1
    ysize = (lat_max - lat_min) / resolution + 1
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)  # output coordinate system: WGS 84
    print(xsize, ysize)
    dem_arr = np.full((int(ysize) + 1, int(xsize) + 1), -9999.0)

    # Map every sample's lon/lat to fractional output pixel coordinates.
    inv_gt = gdal.InvGeoTransform(gt)
    X = inv_gt[0] + lon_data * inv_gt[1] + inv_gt[2] * lat_data
    Y = inv_gt[3] + lon_data * inv_gt[4] + inv_gt[5] * lat_data
    XY = np.empty((X.shape[0], 2))
    XY[:, 0] = X
    XY[:, 1] = Y
    tree = KDTree(XY, leaf_size=2)
    dem_arr = geoCoding(tree,
                        math.ceil(np.min(X)), math.floor(np.max(X)),
                        math.ceil(np.min(Y)), math.floor(np.max(Y)),
                        block_size, values, dem_arr)

    os.makedirs(pack_path, exist_ok=True)
    saveTiff(product_name, xsize, ysize, gt, srs, dem_arr)
    return product_name
if __name__ == '__main__':
    isce_work_path = r"D:\micro\microproduct_depdence\GF3-Deformation\isce_work"
    temp_work_path = r"D:\micro\microproduct_depdence\GF3-Deformation\test"
    out_work_path = r"D:\micro\microproduct_depdence\GF3-Deformation\test"
    product = r'D:\micro\microproduct_depdence\GF3-Deformation\test\dem.tiff'
    # get_Dem requires the radar-geometry array to geocode.  The previous
    # 4-argument call raised TypeError.  Geocode the ISCE height layer here.
    hgt_array = read_tiff_dataset(get_hgt_path(isce_work_path))
    get_Dem(isce_work_path, temp_work_path, out_work_path, product, hgt_array)