microproduct/backScattering/OrthoAuxData.py

416 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Auxiliary-data processing class for one-meter orthorectification
import time
import math
import numpy as np
from osgeo import gdal
from xml.etree.ElementTree import ElementTree
from scipy.optimize import leastsq
class OrthoAuxData:
    """Static helpers for orthorectification auxiliary data.

    Groups three kinds of utilities, all exposed as ``@staticmethod``:

    * readers for values stored in a metadata XML file,
    * readers that turn a DEM raster (opened with GDAL) into coordinates,
    * a cubic least-squares orbit fit and WGS-84 coordinate conversions.

    The class keeps no state.
    """

    def __init__(self):
        pass

    @staticmethod
    def _meta_root(meta_file_path):
        """Parse the XML file at *meta_file_path* and return its root element."""
        tree = ElementTree()
        tree.parse(meta_file_path)
        return tree.getroot()

    @staticmethod
    def time_stamp(tm):
        """Convert a ``YYYY-mm-dd HH:MM:SS[.fff]`` string to a timestamp.

        The seconds field is rounded UP to a whole second with ``math.ceil``,
        so sub-second precision is discarded. The broken-down time is
        converted with ``time.mktime``, i.e. interpreted as local time.

        Returns:
            float: seconds since the epoch.
        """
        parts = tm.split(':')
        sec = math.ceil(float(parts[2]))
        whole_seconds = parts[0] + ':' + parts[1] + ':' + str(sec)
        parsed = time.strptime(whole_seconds, "%Y-%m-%d %H:%M:%S")
        return float(time.mktime(parsed))

    @staticmethod
    def read_meta(meta_file_path):
        """Read the time, position and velocity of every child of ``<GPS>``.

        Returns:
            tuple: ``(T, meta_data)`` where ``T`` is the list of timestamps
            (see :meth:`time_stamp`) and ``meta_data`` is
            ``[Xs, Ys, Zs, Vsx, Vsy, Vsz]``, one list of floats per tag.
        """
        root = OrthoAuxData._meta_root(meta_file_path)
        tags = ('xPosition', 'yPosition', 'zPosition',
                'xVelocity', 'yVelocity', 'zVelocity')
        T = []
        meta_data = [[] for _ in tags]
        for child in root.find('GPS'):
            for column, tag in zip(meta_data, tags):
                column.append(float(child.find(tag).text))
            T.append(OrthoAuxData.time_stamp(child.find('TimeStamp').text))
        return T, meta_data

    @staticmethod
    def read_control_points(meta_file_path):
        """Read the image centre and corner coordinates from ``<imageinfo>``.

        Returns:
            list: ``[longitudes, latitudes]``; in each list the centre comes
            first, followed by the children of ``<corner>`` in file order.
        """
        root = OrthoAuxData._meta_root(meta_file_path)
        imageinfo = root.find('imageinfo')
        points = [imageinfo.find('center')] + list(imageinfo.find('corner'))
        ctrl_pts = [[], []]
        for point in points:
            ctrl_pts[0].append(float(point.find('longitude').text))
            ctrl_pts[1].append(float(point.find('latitude').text))
        return ctrl_pts

    @staticmethod
    def read_dem(dem_resampled_path, flag=1):
        """Read a whole DEM raster and return one position per pixel.

        Pixel ``(x, y)`` is mapped to ``longitude = gt[0] + x * gt[1]`` and
        ``latitude = gt[3] + y * gt[5]`` (``gt`` is the GDAL geotransform;
        its rotation terms are not used) and the band value is the altitude.

        Args:
            dem_resampled_path: path of the raster to open with GDAL.
            flag: ``1`` converts to X/Y/Z via :meth:`LLA2XYZM`; any other
                value keeps longitude/latitude/altitude.

        Returns:
            numpy.ndarray: float64 array of shape ``(rows, cols, 3)``. When
            the raster has several bands the last band determines the result.

        Raises:
            ValueError: if the raster has no bands.
        """
        in_ds = gdal.Open(dem_resampled_path)
        gt = list(in_ds.GetGeoTransform())
        bands_num = in_ds.RasterCount
        x_size = in_ds.RasterXSize
        y_size = in_ds.RasterYSize
        if bands_num < 1:
            raise ValueError("raster has no bands: %s" % dem_resampled_path)
        # np.float was removed from NumPy; float64 is the type it aliased.
        pstn_arr = np.zeros([y_size, x_size, 3], dtype=np.float64)
        # Column / row index of every pixel, both shaped (y_size, x_size).
        cols, rows = np.meshgrid(np.arange(x_size), np.arange(y_size))
        longitude = gt[0] + cols * gt[1]
        latitude = gt[3] + rows * gt[5]
        for i in range(1, bands_num + 1):
            data = in_ds.GetRasterBand(i).ReadAsArray(0, 0, x_size, y_size)
            altitude = np.asarray(data, dtype=np.float64)
            if flag == 1:
                # LLA2XYZM returns flattened column vectors; restore the grid.
                pstn = [axis.reshape(y_size, x_size)
                        for axis in OrthoAuxData.LLA2XYZM(longitude, latitude, altitude)]
            else:
                pstn = [longitude, latitude, altitude]
            for k in range(3):
                pstn_arr[:, :, k] = pstn[k]
        del in_ds
        return pstn_arr

    @staticmethod
    def read_demM(dem_resampled_path, part_cnt, r_cnt, c_cnt, flag=1):
        """Read one tile of a DEM raster and return its positions.

        The raster is divided into ``part_cnt`` x ``part_cnt`` tiles of
        ``RasterXSize // part_cnt`` by ``RasterYSize // part_cnt`` pixels;
        the tile at row ``r_cnt`` / column ``c_cnt`` is read.

        Args:
            dem_resampled_path: path of the raster to open with GDAL.
            part_cnt: number of tiles along each axis.
            r_cnt: tile row index.
            c_cnt: tile column index.
            flag: ``1`` converts to X/Y/Z via :meth:`LLA2XYZM`; any other
                value keeps longitude/latitude/altitude.

        Returns:
            list: three arrays. With ``flag == 1`` these are the ``(N, 1)``
            X, Y, Z columns; otherwise the 2-D longitude, latitude and
            altitude grids. The last band determines the result.

        Raises:
            ValueError: if the raster has no bands.
        """
        in_ds = gdal.Open(dem_resampled_path)
        gt = list(in_ds.GetGeoTransform())
        bands_num = in_ds.RasterCount
        x_size = in_ds.RasterXSize // part_cnt
        y_size = in_ds.RasterYSize // part_cnt
        # Column / row index of every pixel of the tile, (y_size, x_size).
        x, y = np.meshgrid(np.arange(x_size), np.arange(y_size))
        x_off = c_cnt * x_size
        y_off = r_cnt * y_size
        # Shift the geotransform origin to the tile's upper-left pixel.
        gt[0] = gt[0] + c_cnt * x_size * gt[1]
        gt[3] = gt[3] + r_cnt * y_size * gt[5]
        pstn = None
        for i in range(1, bands_num + 1):
            data = in_ds.GetRasterBand(i).ReadAsArray(x_off, y_off, x_size, y_size)
            # NOTE(review): rescales band values by 1024/255 - presumably
            # the input is an 8-bit encoded DEM; confirm against the data.
            altitude = data / 255 * 1024
            longitude = gt[0] + x * gt[1]
            latitude = gt[3] + y * gt[5]
            if flag == 1:
                pstn = OrthoAuxData.LLA2XYZM(longitude, latitude, altitude)
            else:
                pstn = [longitude, latitude, altitude]
        del in_ds
        if pstn is None:
            raise ValueError("raster has no bands: %s" % dem_resampled_path)
        return pstn

    @staticmethod
    def read_dem_row(dem_resampled_path, p, flag=1):
        """Read a single row of a DEM raster and return its positions.

        Args:
            dem_resampled_path: path of the raster to open with GDAL.
            p: index of the row to read.
            flag: ``1`` converts to X/Y/Z via :meth:`LLA2XYZM`; any other
                value keeps longitude/latitude/altitude.

        Returns:
            list: three arrays. With ``flag == 1`` these are the ``(N, 1)``
            X, Y, Z columns; otherwise ``(1, N)`` longitude, latitude and
            altitude rows. The last band determines the result.

        Raises:
            ValueError: if the raster has no bands.
        """
        in_ds = gdal.Open(dem_resampled_path)
        gt = list(in_ds.GetGeoTransform())
        bands_num = in_ds.RasterCount
        x_size = in_ds.RasterXSize
        x = np.arange(x_size).reshape(1, -1)   # column indices, shape (1, N)
        y = np.ones((1, x_size)) * p           # constant row index
        pstn = None
        for i in range(1, bands_num + 1):
            data = in_ds.GetRasterBand(i).ReadAsArray(0, p, x_size, 1)
            altitude = data
            longitude = gt[0] + x * gt[1]
            latitude = gt[3] + y * gt[5]
            if flag == 1:
                pstn = OrthoAuxData.LLA2XYZM(longitude, latitude, altitude)
            else:
                pstn = [longitude, latitude, altitude]
        del in_ds
        if pstn is None:
            raise ValueError("raster has no bands: %s" % dem_resampled_path)
        return pstn

    @staticmethod
    def orbit_fitting(time_array, meta_data):
        """Fit a cubic polynomial in time to every series of *meta_data*.

        Times are first centred on ``T0``, the midpoint of the first and last
        entries of *time_array*. Each series is then fitted by least squares
        with ``w3*t**3 + w2*t**2 + w1*t + w0``.

        Args:
            time_array: sequence of timestamps.
            meta_data: sequence of value series, each as long as *time_array*.

        Returns:
            tuple: ``(orbital_paras, T0)`` where ``orbital_paras`` holds one
            ``[w3, w2, w1, w0]`` array per series.
        """
        T0 = (time_array[0] + time_array[len(time_array) - 1]) / 2
        x = np.array([stamp - T0 for stamp in time_array])

        def func(p, x):
            w3, w2, w1, w0 = p
            return w3*x**3 + w2*x**2 + w1*x + w0

        def error(p, x, y):
            return func(p, x) - y

        orbital_paras = []
        for series in meta_data:
            p0 = [1, 2, 3, 4]  # arbitrary starting point for the solver
            y = np.array(series)
            Para = leastsq(error, p0, args=(x, y))
            orbital_paras.append(Para[0])
            # Para[1] is the solver's integer status flag.
            print(Para[0], Para[1])
        return orbital_paras, T0

    @staticmethod
    def get_PRF(meta_file_path):
        """Return ``sensor/waveParams/wave/prf`` from the metadata as a float."""
        sensor = OrthoAuxData._meta_root(meta_file_path).find('sensor')
        wave = sensor.find('waveParams').find('wave')
        return float(wave.find('prf').text)

    @staticmethod
    def get_delta_R(meta_file_path):
        """Return ``c / (1000000 * 2 * bandWidth)`` for the metadata's bandwidth.

        ``bandWidth`` is read from ``sensor/waveParams/wave/bandWidth`` and
        ``c`` is 299792458. The factor 1000000 presumably converts a
        bandwidth given in MHz to Hz - confirm against the metadata spec.
        """
        sensor = OrthoAuxData._meta_root(meta_file_path).find('sensor')
        wave = sensor.find('waveParams').find('wave')
        bandWidth = float(wave.find('bandWidth').text)
        c = 299792458
        return c / (1000000 * 2 * bandWidth)

    @staticmethod
    def get_doppler_rate_coef(meta_file_path):
        """Read the Doppler-rate reference time and coefficients.

        Returns:
            tuple: ``(t0, coef)`` where ``t0`` is
            ``processinfo/DopplerParametersReferenceTime`` and ``coef`` is a
            ``(5, 1)`` array holding ``r0`` .. ``r4`` of
            ``processinfo/DopplerRateValuesCoefficients``.
        """
        processinfo = OrthoAuxData._meta_root(meta_file_path).find('processinfo')
        doppler = processinfo.find('DopplerRateValuesCoefficients')
        t0 = float(processinfo.find('DopplerParametersReferenceTime').text)
        coefficients = [float(doppler.find(tag).text)
                        for tag in ('r0', 'r1', 'r2', 'r3', 'r4')]
        return t0, np.array(coefficients).reshape(5, 1)

    @staticmethod
    def get_doppler_center_coef(meta_file_path):
        """Return ``(d0, d1, d2)`` of ``processinfo/DopplerCentroidCoefficients``."""
        processinfo = OrthoAuxData._meta_root(meta_file_path).find('processinfo')
        doppler = processinfo.find('DopplerCentroidCoefficients')
        b0 = float(doppler.find('d0').text)
        b1 = float(doppler.find('d1').text)
        b2 = float(doppler.find('d2').text)
        return b0, b1, b2

    @staticmethod
    def get_lamda(meta_file_path):
        """Return ``sensor/lamda`` from the metadata as a float."""
        sensor = OrthoAuxData._meta_root(meta_file_path).find('sensor')
        wavelength = float(sensor.find('lamda').text)
        return wavelength

    @staticmethod
    def get_t0(meta_file_path):
        """Return ``imageinfo/imagingTime/start`` as a timestamp."""
        imageinfo = OrthoAuxData._meta_root(meta_file_path).find('imageinfo')
        tm = imageinfo.find('imagingTime').find('start').text
        return OrthoAuxData.time_stamp(tm)

    @staticmethod
    def get_start_and_end_time(meta_file_path):
        """Return ``(start, end)`` of ``imageinfo/imagingTime`` as timestamps."""
        imageinfo = OrthoAuxData._meta_root(meta_file_path).find('imageinfo')
        imaging_time = imageinfo.find('imagingTime')
        starttime = OrthoAuxData.time_stamp(imaging_time.find('start').text)
        endtime = OrthoAuxData.time_stamp(imaging_time.find('end').text)
        return starttime, endtime

    @staticmethod
    def get_width_and_height(meta_file_path):
        """Return ``(width, height)`` of ``imageinfo`` as integers."""
        imageinfo = OrthoAuxData._meta_root(meta_file_path).find('imageinfo')
        width = int(imageinfo.find('width').text)
        height = int(imageinfo.find('height').text)
        return width, height

    @staticmethod
    def get_R0(meta_file_path):
        """Return ``imageinfo/nearRange`` from the metadata as a float."""
        imageinfo = OrthoAuxData._meta_root(meta_file_path).find('imageinfo')
        return float(imageinfo.find('nearRange').text)

    @staticmethod
    def get_h():
        """Return the constant 6.6 (its meaning is not evident from this file)."""
        h = 6.6
        return h

    @staticmethod
    def LLA2XYZ(longitude, latitude, altitude):
        """Convert one WGS-84 longitude/latitude/altitude to X, Y, Z.

        Args:
            longitude: longitude in degrees (scalar).
            latitude: latitude in degrees (scalar).
            altitude: height above the ellipsoid, in the unit of the
                semi-major axis constant below.

        Returns:
            list: ``[X, Y, Z]``.
        """
        # Sines and cosines of latitude and longitude.
        cosLat = math.cos(latitude * math.pi / 180)
        sinLat = math.sin(latitude * math.pi / 180)
        cosLon = math.cos(longitude * math.pi / 180)
        sinLon = math.sin(longitude * math.pi / 180)
        # WGS-84 parameters.
        rad = 6378137.0       # equatorial (semi-major) radius
        # NOTE(review): the WGS-84 inverse flattening is 298.257223563;
        # the rounded value below is kept to preserve existing results.
        f = 1.0/298.257224    # ellipsoid flattening
        C = 1.0/math.sqrt(cosLat*cosLat + (1-f)*(1-f)*sinLat*sinLat)
        S = (1-f)*(1-f)*C
        h = altitude
        # Cartesian coordinates.
        X = (rad * C + h) * cosLat * cosLon
        Y = (rad * C + h) * cosLat * sinLon
        Z = (rad * S + h) * sinLat
        return [X, Y, Z]

    @staticmethod
    def LLA2XYZM(longitude, latitude, altitude):
        """Array version of :meth:`LLA2XYZ`.

        Args:
            longitude: ndarray of longitudes in degrees.
            latitude: ndarray of latitudes in degrees.
            altitude: ndarray of heights, same number of elements.

        Returns:
            list: ``[X, Y, Z]``, each an ``(N, 1)`` column holding the inputs'
            elements in flattened order.
        """
        # Sines and cosines of latitude and longitude, as column vectors.
        cosLat = np.cos(latitude * math.pi / 180).reshape(-1, 1)
        sinLat = np.sin(latitude * math.pi / 180).reshape(-1, 1)
        cosLon = np.cos(longitude * math.pi / 180).reshape(-1, 1)
        sinLon = np.sin(longitude * math.pi / 180).reshape(-1, 1)
        # WGS-84 parameters (same constants as LLA2XYZ).
        rad = 6378137.0       # equatorial (semi-major) radius
        f = 1.0/298.257224    # ellipsoid flattening
        C = 1.0/(np.sqrt(cosLat*cosLat + (1-f)*(1-f)*sinLat*sinLat)).reshape(-1, 1)
        S = (1-f)*(1-f)*C
        h = altitude.reshape(-1, 1)
        # Cartesian coordinates.
        X = (rad * C + h) * cosLat * cosLon
        Y = (rad * C + h) * cosLat * sinLon
        Z = (rad * S + h) * sinLat
        return [X, Y, Z]

    @staticmethod
    def XYZ2LLA(X, Y, Z):
        """Convert WGS-84 X, Y, Z to longitude, latitude and altitude.

        Works on scalars and on NumPy arrays (only NumPy ufuncs are used).
        The latitude is obtained in a single pass from the auxiliary angle
        ``theta``, without iteration. The altitude is computed as
        ``p / cos(latitude) - N`` and is therefore undefined at the poles.

        Args:
            X, Y, Z: Cartesian coordinates.

        Returns:
            list: ``[longitude, latitude, altitude]`` with both angles in
            degrees - note that longitude comes first.
        """
        # WGS-84 parameters.
        a = 6378137.0          # semi-major axis
        b = 6356752.314245     # semi-minor axis
        ea = np.sqrt((a ** 2 - b ** 2) / a ** 2)
        eb = np.sqrt((a ** 2 - b ** 2) / b ** 2)
        p = np.sqrt(X ** 2 + Y ** 2)
        theta = np.arctan2(Z * a, p * b)
        # Longitude, latitude and altitude.
        longitude = np.arctan2(Y, X)
        latitude = np.arctan2(Z + eb ** 2 * b * np.sin(theta) ** 3,
                              p - ea ** 2 * a * np.cos(theta) ** 3)
        N = a / np.sqrt(1 - ea ** 2 * np.sin(latitude) ** 2)
        altitude = p / np.cos(latitude) - N
        return [np.degrees(longitude), np.degrees(latitude), altitude]

    @staticmethod
    def XYZ2LLAM(X, Y, Z):
        """Convert WGS-84 X, Y, Z to longitude, latitude and altitude.

        Identical to :meth:`XYZ2LLA`, which already accepts arrays; this name
        is kept for existing callers.

        Returns:
            list: ``[longitude, latitude, altitude]`` with both angles in
            degrees - note that longitude comes first.
        """
        return OrthoAuxData.XYZ2LLA(X, Y, Z)

    @staticmethod
    def world2Pixel(geoMatrix, x, y):
        """Return the ``(pixel, line)`` of geographic position ``(x, y)``.

        *geoMatrix* is a GDAL geotransform (``GetGeoTransform()``). Only the
        origin and the pixel sizes are used; the rotation terms are ignored.
        Both results are truncated toward zero with ``int``.
        """
        ulx = geoMatrix[0]
        uly = geoMatrix[3]
        xDist = geoMatrix[1]
        yDist = geoMatrix[5]
        pixel = int((x - ulx) / xDist)
        line = int((uly - y) / abs(yDist))
        return pixel, line

    @staticmethod
    def sar_intensity_synthesis(in_sar_tif, out_sar_tif):
        """Combine bands 1 and 2 of a raster into a single-band GeoTIFF.

        Each band value ``v`` is converted to ``10 ** (v / 10)`` and the
        output is ``sqrt(band1 ** 2 + band2 ** 2)`` of the converted values.
        Projection and geotransform are copied from the input.

        Args:
            in_sar_tif: path of the input raster (needs at least two bands).
            out_sar_tif: path of the GeoTIFF to create.
        """
        # Size and georeferencing of the input image.
        in_ds = gdal.Open(in_sar_tif)
        rows = in_ds.RasterYSize
        columns = in_ds.RasterXSize
        proj = in_ds.GetProjection()
        geotrans = in_ds.GetGeoTransform()
        # Create the output raster.
        # NOTE(review): no pixel type is passed to Create(), so GDAL's default
        # applies (Byte per the GDAL docs) - confirm that is intended for the
        # floating-point values written below.
        gtiff_driver = gdal.GetDriverByName('GTiff')
        out_ds = gtiff_driver.Create(out_sar_tif, columns, rows, 1)
        out_ds.SetProjection(proj)
        out_ds.SetGeoTransform(geotrans)
        # Convert both bands and write their combined magnitude.
        converted = []
        for band_index in (1, 2):
            band = in_ds.GetRasterBand(band_index).ReadAsArray(0, 0, columns, rows)
            converted.append(np.power(10, band / 10))
        out_data = np.sqrt(converted[0] ** 2 + converted[1] ** 2)
        out_ds.GetRasterBand(1).WriteArray(out_data)
        # Dropping the references closes the datasets and flushes the output.
        del in_ds, out_ds