microproduct/dem-C-SAR/geocoding.py

300 lines
12 KiB
Python

##########################
# Geocoding
##########################
from osgeo import osr
from osgeo import gdal
import numpy as np
import os
import math
from sklearn.neighbors import KDTree
from xml.etree import ElementTree as ET
from tool.algorithm.image.ImageHandle import ImageHandler
# Import-time side effect: set the PROJ_LIB environment variable to the
# directory the process was started from.
# NOTE(review): presumably this lets PROJ/GDAL locate their resource files
# (e.g. proj.db) next to the executable — confirm the working directory
# always contains them.
os.environ['PROJ_LIB'] = os.getcwd()
def get_lambel_sentinel(isce_work_path):
    """Read the radar wavelength from the reference IW XML of an ISCE workspace.

    Scans ``<isce_work_path>/reference`` for files whose name starts with
    ``IW`` and contains ``.xml``. When several match, the last one in sorted
    order is used, so the choice does not depend on directory listing order.

    Args:
        isce_work_path: Root directory of the ISCE workspace.

    Returns:
        The wavelength as a float, taken from the ``<value>`` element nested
        under the element whose ``name`` attribute is ``radarwavelength``;
        ``None`` when the XML contains no such element.

    Raises:
        FileNotFoundError: If the ``reference`` directory does not exist or
            holds no matching IW XML file.
    """
    reference_dir = os.path.join(isce_work_path, "reference")
    iw_xml_path = None
    for entry in sorted(os.listdir(reference_dir)):
        if entry.startswith("IW") and ".xml" in entry:
            iw_xml_path = os.path.join(reference_dir, entry)
    if iw_xml_path is None:
        # Fail with a clear message instead of handing None to ET.parse.
        raise FileNotFoundError(
            "no IW*.xml file found in {}".format(reference_dir))

    # Parse the wavelength.
    root = ET.parse(iw_xml_path).getroot()
    for node in root.iter():
        if node.attrib.get("name") != "radarwavelength":
            continue
        for child in node.iter():
            if child.tag == "value":
                return float(child.text.replace("\n", ""))
    return None
def get_long_lat_path(isce_work_path):
    """Return the longitude and latitude raster paths of an ISCE workspace.

    Both files are expected directly under ``<isce_work_path>/geom_reference``
    (not under a ``merged`` sub-directory).

    Args:
        isce_work_path: Root directory of the ISCE workspace.

    Returns:
        A two-element list ``[lon_path, lat_path]``.
    """
    geom_dir = os.path.join(isce_work_path, "geom_reference")
    return [os.path.join(geom_dir, name) for name in ("lon.rdr.vrt", "lat.rdr.vrt")]
def get_hgt_path(isce_work_path):
    """Return the path of ``hgt.rdr.vrt`` inside ``<isce_work_path>/geom_reference``.

    The file is looked up directly under the workspace, not under ``merged``.
    """
    geom_dir = os.path.join(isce_work_path, "geom_reference")
    return os.path.join(geom_dir, "hgt.rdr.vrt")
def get_Los_path(isce_work_path):
    """Return the path of ``los.rdr.vrt`` inside ``<isce_work_path>/geom_reference``.

    The file is looked up directly under the workspace, not under ``merged``.
    """
    geom_dir = os.path.join(isce_work_path, "geom_reference")
    return os.path.join(geom_dir, "los.rdr.vrt")
def get_filt_fine_unw_path(isce_work_path):
    """Collect every ``*.unw.vrt`` file below ``<isce_work_path>/Igrams``.

    Args:
        isce_work_path: Root directory of the ISCE workspace.

    Returns:
        A list of full paths in ``os.walk`` order; empty when the directory
        is missing or contains no matching file.
    """
    igrams_dir = os.path.join(isce_work_path, "Igrams")
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(igrams_dir)
        for name in names
        if name.endswith(".unw.vrt")
    ]
def get_filt_fine_cor_path(isce_work_path):
    """Collect every ``*.cor.vrt`` file below ``<isce_work_path>/Igrams``.

    Args:
        isce_work_path: Root directory of the ISCE workspace.

    Returns:
        A list of full paths in ``os.walk`` order; empty when the directory
        is missing or contains no matching file.
    """
    igrams_dir = os.path.join(isce_work_path, "Igrams")
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(igrams_dir)
        for name in names
        if name.endswith(".cor.vrt")
    ]
def info_VRT(vrt_path):
    """Open a raster read-only, print a short summary and return the dataset.

    Prints the width, height and band count, then one line per band with
    that band's dtype, minimum and maximum. Each band is read on its own,
    so the statistics describe that band rather than the whole dataset.

    Args:
        vrt_path: Path of a raster that GDAL can open.

    Returns:
        The opened GDAL dataset; the caller owns it.

    Raises:
        RuntimeError: If GDAL cannot open ``vrt_path``.
    """
    vrt_data = gdal.Open(vrt_path, gdal.GA_ReadOnly)
    if vrt_data is None:
        # gdal.Open returns None on failure unless GDAL exceptions are enabled.
        raise RuntimeError("GDAL could not open {}".format(vrt_path))
    print("width:\t", vrt_data.RasterXSize)
    print("height:\t", vrt_data.RasterYSize)
    print("bands numbers:\t", vrt_data.RasterCount)
    for band_no in range(1, vrt_data.RasterCount + 1):  # GDAL bands are 1-based
        data = vrt_data.GetRasterBand(band_no).ReadAsArray()
        print("band {} ,dtype:{},\tmin:{}\tmax:\t{}".format(
            band_no, data.dtype, np.min(data), np.max(data)))
    return vrt_data
def vrt2tiff(vrt_path, tiff_path, band_idx=0):
    """Copy one band of a raster into a new single-band Float32 GeoTIFF.

    Args:
        vrt_path: Source raster; it is opened (and summarised) by ``info_VRT``.
        tiff_path: Destination GeoTIFF path.
        band_idx: Zero-based index of the band to copy. Only used when the
            source has more than one band; a single-band source is copied
            as-is.
    """
    source = info_VRT(vrt_path)
    cols, rows = source.RasterXSize, source.RasterYSize
    # The data type must be given so the driver can size the output file.
    gtiff = gdal.GetDriverByName('GTiff')
    out_ds = gtiff.Create(tiff_path, cols, rows, 1, gdal.GDT_Float32)
    pixels = source.ReadAsArray(0, 0, cols, rows)
    if source.RasterCount > 1:
        pixels = pixels[band_idx, :, :]
    out_ds.GetRasterBand(1).WriteArray(pixels)
    # Dropping the last reference makes GDAL flush and close the file.
    del out_ds
def read_tiff_dataset(tiff_data, band_idx=0):
    """Read a raster file into a 2-D array.

    Args:
        tiff_data: Path of the raster to read.
        band_idx: Zero-based band to return; only used when the file has
            more than one band.

    Returns:
        The pixel array of the selected band (or of the only band).
    """
    raster = gdal.Open(tiff_data, gdal.GA_ReadOnly)
    pixels = raster.ReadAsArray(0, 0, raster.RasterXSize, raster.RasterYSize)
    if raster.RasterCount > 1:
        pixels = pixels[band_idx, :, :]
    del raster
    return pixels
def saveTiff(target_data_path, xsize, ysize, gt, srs, target_arr):
    """Write ``target_arr`` as a georeferenced single-band Float32 GeoTIFF.

    Args:
        target_data_path: Output file path.
        xsize: Grid width; the file is created ``int(xsize) + 1`` pixels wide.
        ysize: Grid height; the file is created ``int(ysize) + 1`` pixels high.
        gt: Geotransform passed to ``SetGeoTransform``.
        srs: Spatial reference whose WKT is stored as the projection.
        target_arr: Array written to band 1; ``-9999`` is registered as the
            no-data value.
    """
    cols = int(xsize) + 1
    rows = int(ysize) + 1
    # The data type must be given so the driver can size the output file.
    out_ds = gdal.GetDriverByName('GTiff').Create(
        target_data_path, cols, rows, 1, gdal.GDT_Float32)
    band = out_ds.GetRasterBand(1)
    band.WriteArray(target_arr)
    band.SetNoDataValue(-9999)
    out_ds.SetGeoTransform(gt)
    out_ds.SetProjection(srs.ExportToWkt())
    # Release the band before the dataset so GDAL flushes and closes the file.
    del band
    del out_ds
# Geocoding by nearest-neighbour interpolation
def geoCoding(tree, X_min, X_max, Y_min, Y_max, block_size, value_data, target_arr,
              k=4, max_dist=2):
    """Fill ``target_arr`` by inverse-distance interpolation of scattered samples.

    Every integer pixel ``(x, y)`` with ``X_min <= x <= X_max`` and
    ``Y_min <= y <= Y_max`` is looked up in ``tree``, block by block to bound
    memory use:

    * if any of its ``k`` nearest samples lies at distance 0, the pixel takes
      the value of the nearest sample;
    * otherwise, if the k-th nearest sample is within ``max_dist`` pixels, the
      pixel takes the inverse-distance weighted mean of the ``k`` samples;
    * otherwise the pixel is left untouched.

    Args:
        tree: Object with ``query(points, k=...)`` returning ``(dist, ind)``
            arrays of shape ``(n, k)`` sorted by distance, such as
            ``sklearn.neighbors.KDTree``, built on ``(x, y)`` pixel coordinates.
        X_min, X_max, Y_min, Y_max: Inclusive integer pixel bounds.
        block_size: Positive edge length, in pixels, of one processing block.
        value_data: 1-D array of sample values, indexed like the tree's points.
        target_arr: 2-D output array indexed ``[y, x]``; modified in place.
        k: Number of neighbours used per pixel (default 4).
        max_dist: Largest accepted distance, in pixels, to the k-th neighbour
            (default 2).

    Returns:
        ``target_arr``, for convenience.
    """
    for i in range(X_min, X_max + 1, block_size):
        last_i = min(i + block_size - 1, X_max)
        cols = np.arange(i, last_i + 1)
        for j in range(Y_min, Y_max + 1, block_size):
            last_j = min(j + block_size - 1, Y_max)
            rows = np.arange(j, last_j + 1)
            grid_x, grid_y = np.meshgrid(cols, rows)
            px = grid_x.reshape(-1)
            py = grid_y.reshape(-1)
            dist, ind = tree.query(np.column_stack((px, py)), k=k)

            # Pixels that coincide with a sample: copy the nearest value,
            # which also avoids dividing by a zero distance below.
            exact = (dist == 0).any(axis=1)
            target_arr[py[exact], px[exact]] = value_data[ind[exact, 0]]

            # Remaining pixels whose k-th neighbour is close enough:
            # inverse-distance weighting.
            near = ~exact & (dist[:, -1] <= max_dist)
            weights = 1 / dist[near]
            weights = weights / weights.sum(axis=1).reshape(-1, 1)
            target_arr[py[near], px[near]] = np.sum(
                weights * value_data[ind[near]], axis=1)
        # Progress indicator; guarded so X_max == 0 cannot divide by zero.
        progress = last_i / X_max if X_max else 1.0
        print(progress, "...", end="")
    return target_arr
def detrend_2d(unw_filename, cor_filename, out_file):
    """Remove a fitted quadratic surface from an unwrapped-phase image.

    A surface ``a*x + b*y + c*x*y + d*x*x + e*y*y`` (no constant term) is
    least-squares fitted to a sparse grid of samples, one roughly every 0.5 %
    of the image size. Samples whose coherence is below 0.2 are skipped. The
    fitted surface is subtracted from every pixel, and pixels whose coherence
    is exactly 0 are set to 0.

    The fit and the evaluation both use 1-based pixel coordinates
    (``x = column + 1``, ``y = row + 1``), so the surface is subtracted at
    the same coordinates it was fitted at.

    Args:
        unw_filename: Path of the unwrapped-phase raster.
        cor_filename: Path of the coherence raster.
            NOTE(review): both are assumed to load as 2-D arrays of the same
            shape via ``ImageHandler.get_data`` — confirm.
        out_file: Path the detrended image is written to with
            ``ImageHandler.write_img``.

    Returns:
        The detrended image as a float array shaped like the coherence image.
    """
    unwImg = ImageHandler.get_data(unw_filename)
    corImg = ImageHandler.get_data(cor_filename)
    height = corImg.shape[0]
    width = corImg.shape[1]

    # Sampling interval; at least 1 so images under 200 pixels on a side do
    # not produce a zero step.
    lines_intv = max(1, int(np.floor(height * 0.005)))
    width_intv = max(1, int(np.floor(width * 0.005)))
    sample_rows = np.arange(1, height, lines_intv)
    sample_cols = np.arange(1, width, width_intv)
    rr, cc = np.meshgrid(sample_rows, sample_cols, indexing="ij")

    # Keep every sample that is not below the 0.2 coherence threshold.
    usable = ~(corImg[rr, cc] < 0.2)
    rr = rr[usable]
    cc = cc[usable]
    point_x = cc.astype(float) + 1.0
    point_y = rr.astype(float) + 1.0
    point_z = unwImg[rr, cc]

    design_matrix = np.column_stack(
        (point_x, point_y, point_x * point_y, point_x * point_x, point_y * point_y))
    a, b, c, d, e = np.linalg.lstsq(design_matrix, point_z, rcond=None)[0]

    # Evaluate the surface on the whole image at once.
    xx = np.arange(1, width + 1, dtype=float)[np.newaxis, :]
    yy = np.arange(1, height + 1, dtype=float)[:, np.newaxis]
    trend = a * xx + b * yy + c * xx * yy + d * xx * xx + e * yy * yy
    dtd_unw = np.asarray(unwImg, dtype=float) - trend
    dtd_unw[np.where(corImg == 0)] = 0

    ImageHandler.write_img(out_file, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], dtd_unw)
    return dtd_unw
def get_Dem(isce_work_path, temp_work_path, pack_path, product_name, lamda):
    """Build a geocoded DEM GeoTIFF from the products of an ISCE workspace.

    Steps:
      1. Convert the first ``*.unw.vrt`` interferogram (band index 1) and the
         lon/lat/hgt/los geometry rasters to GeoTIFFs in ``temp_work_path``.
      2. Detrend the unwrapped phase with ``detrend_2d``, using the first
         ``*.cor.vrt`` file as coherence.
      3. Keep the pixels whose first los band is > 0, scale the phase by
         ``-lamda / (4 * pi)`` and add it to the ISCE height.
      4. Resample the scattered (lon, lat, height) samples onto a regular
         EPSG:4326 grid with a pixel size of 1/3600 degree and write it to
         ``product_name``.

    Args:
        isce_work_path: Root directory of the ISCE workspace.
        temp_work_path: Existing directory for intermediate GeoTIFFs.
        pack_path: Output directory; created if missing.
        product_name: Full path of the DEM GeoTIFF to write.
        lamda: Radar wavelength. NOTE(review): presumably in metres, the same
            unit as the height raster — confirm.

    Returns:
        ``product_name``.

    Raises:
        FileNotFoundError: If no ``*.unw.vrt`` or ``*.cor.vrt`` file exists
            below ``<isce_work_path>/Igrams``.
    """
    pixel_size = 0.0002777777777777778  # 1/3600 degree

    unw_paths = get_filt_fine_unw_path(isce_work_path)
    cor_paths = get_filt_fine_cor_path(isce_work_path)
    if not unw_paths or not cor_paths:
        raise FileNotFoundError(
            "no *.unw.vrt / *.cor.vrt interferogram found under {}".format(
                os.path.join(isce_work_path, "Igrams")))
    filt_topophase_unw_path = unw_paths[0]
    filt_topophase_cor_path = cor_paths[0]

    # Convert the inputs to GeoTIFFs in the temporary directory.
    unw_tiff_path = os.path.join(temp_work_path, "unw.tiff")
    vrt2tiff(filt_topophase_unw_path, unw_tiff_path, 1)
    lon_path, lat_path = get_long_lat_path(isce_work_path)
    lon_tiff_path = os.path.join(temp_work_path, "lon.tiff")
    vrt2tiff(lon_path, lon_tiff_path)
    lat_tiff_path = os.path.join(temp_work_path, "lat.tiff")
    vrt2tiff(lat_path, lat_tiff_path)
    hgt_tiff_path = os.path.join(temp_work_path, "hgt.tiff")
    vrt2tiff(get_hgt_path(isce_work_path), hgt_tiff_path)
    los_tiff_path = os.path.join(temp_work_path, "los_uwm.tiff")
    vrt2tiff(get_Los_path(isce_work_path), los_tiff_path, 0)

    los_data = read_tiff_dataset(los_tiff_path, band_idx=0)
    lon_data = read_tiff_dataset(lon_tiff_path, band_idx=0)
    lat_data = read_tiff_dataset(lat_tiff_path, band_idx=0)
    out_detrend_path = os.path.join(temp_work_path, "detrend_unw.tif")
    unw_data = detrend_2d(unw_tiff_path, filt_topophase_cor_path, out_detrend_path)
    hgt_data = read_tiff_dataset(hgt_tiff_path, band_idx=0)

    # Flatten everything and keep only the pixels with a positive los value.
    valid = los_data.reshape(-1) > 0
    lon_data = lon_data.reshape(-1)[valid]
    lat_data = lat_data.reshape(-1)[valid]
    unw_data = unw_data.reshape(-1)[valid]
    hgt_data = hgt_data.reshape(-1)[valid]
    print(unw_data.shape, np.min(unw_data), np.max(unw_data))
    del valid, los_data

    # Deformation value from the unwrapped phase, added to the ISCE height.
    def_data = -1 * unw_data * lamda / (4 * math.pi)
    hgt_data = hgt_data + def_data

    # North-up output grid anchored at the upper-left corner of the samples.
    lon_min, lon_max = np.min(lon_data), np.max(lon_data)
    lat_min, lat_max = np.min(lat_data), np.max(lat_data)
    gt = [lon_min, pixel_size, 0.0, lat_max, 0.0, -pixel_size]
    xsize = ((lon_max - lon_min) / pixel_size) + 1
    ysize = ((lat_min - lat_max) / -pixel_size) + 1

    # Output coordinate system: "WGS 84".
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    inv_gt = gdal.InvGeoTransform(gt)
    print(xsize, ysize)

    # Output array pre-filled with the no-data value used by saveTiff.
    dem_arr = np.zeros((int(ysize) + 1, int(xsize) + 1)) - 9999

    # Fractional pixel coordinates of every sample in the output grid.
    X = inv_gt[0] + lon_data * inv_gt[1] + inv_gt[2] * lat_data
    Y = inv_gt[3] + lon_data * inv_gt[4] + inv_gt[5] * lat_data
    XY = np.column_stack((X, Y)).astype(np.float64)
    tree = KDTree(XY, leaf_size=2)
    X_min = math.ceil(np.min(X))
    X_max = math.floor(np.max(X))
    Y_min = math.ceil(np.min(Y))
    Y_max = math.floor(np.max(Y))
    block_size = 3000
    dem_arr = geoCoding(tree, X_min, X_max, Y_min, Y_max, block_size, hgt_data, dem_arr)

    # Tolerates an existing directory and creates missing parents.
    os.makedirs(pack_path, exist_ok=True)
    dem_target_data_path = product_name
    saveTiff(dem_target_data_path, xsize, ysize, gt, srs, dem_arr)
    return dem_target_data_path
if __name__ == '__main__':
    # Manual run against a local workspace; the paths are machine-specific.
    workspace_dir = r"D:\micro\WorkSpace\Dem\Temporary\processing\isce_workspace"
    scratch_dir = r"D:\micro\WorkSpace\Dem\Temporary\processing\isce_workspace\test"
    package_dir = r"D:\micro\WorkSpace\Dem\Temporary\processing\isce_workspace\test"
    dem_product = r'D:\micro\WorkSpace\Dem\Temporary\processing\isce_workspace\test\dem.tiff'
    get_Dem(
        isce_work_path=workspace_dir,
        temp_work_path=scratch_dir,
        pack_path=package_dir,
        product_name=dem_product,
        lamda=0.055517,
    )