更新裁剪方法,通过地距映射表进行斜距影像裁剪

dev
tian jiax 2023-11-22 09:28:51 +08:00
parent 762078fc90
commit c6cbe20b87
3 changed files with 170 additions and 39 deletions

View File

@ -2,7 +2,7 @@ import os
cimport cython # 必须导入 cimport cython # 必须导入
import numpy as np##必须为c类型和python类型的数据都申明一个np import numpy as np##必须为c类型和python类型的数据都申明一个np
cimport numpy as np # 必须为c类型和python类型的数据都申明一个np cimport numpy as np # 必须为c类型和python类型的数据都申明一个np
from libc.math cimport pi from libc.math cimport pi,ceil,floor
from scipy.interpolate import griddata from scipy.interpolate import griddata
@ -86,17 +86,54 @@ cpdef np.ndarray[double,ndim=2] cut_L1A_img(np.ndarray[double,ndim=3] ori2geo_i
while i<height: while i<height:
j=0 j=0
while j<width: while j<width:
temp_p.x=ori2geo_img[0,i,j] temp_p.x=ori2geo_img[0,i,j] # temp_p
temp_p.y=ori2geo_img[1,i,j] temp_p.y=ori2geo_img[1,i,j] # temp_p
if rayCasting(temp_p,roi_list)==1: if rayCasting(temp_p,roi_list)==1:
mask[i,j]=1 mask[i,j]=1
else: else:
mask[i,j]=np.nan mask[i,j]=np.nan
j=j+1 j=j+1
i=i+1 i=i+1
return mask return mask
cpdef np.ndarray[double,ndim=2] gereratorMask(np.ndarray[double,ndim=1] rlist,np.ndarray[double,ndim=1] clist,np.ndarray[double,ndim=2] mask):
cdef int rcount=rlist.shape[0]
cdef int ccount=clist.shape[0]
cdef int count=rcount if rcount<ccount else ccount
cdef int i=0
cdef int j=0
cdef int temp_row=0
cdef int temp_col=0
cdef int height=mask.shape[0]
cdef int width=mask.shape[1]
while i<count:
# 1
temp_row=int(ceil(rlist[i]))
temp_col=int(ceil(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
# 2
temp_row=int(floor(rlist[i]))
temp_col=int(ceil(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
# 3
temp_row=int(ceil(rlist[i]))
temp_col=int(floor(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
# 4
temp_row=int(floor(rlist[i]))
temp_col=int(floor(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
i=i+1
return mask
cdef double distance_powe(Point p1,Point p2): cdef double distance_powe(Point p1,Point p2):

View File

@ -4,13 +4,14 @@ import numpy as np
import scipy import scipy
from scipy.interpolate import griddata, RegularGridInterpolator from scipy.interpolate import griddata, RegularGridInterpolator
import logging import logging
import pyresample as pr # import pyresample as pr
# 插值模块 # 插值模块
from pyresample.bilinear import NumpyBilinearResampler from pyresample.bilinear import NumpyBilinearResampler
from pyresample import geometry from pyresample import geometry
from pyresample.geometry import AreaDefinition from pyresample.geometry import AreaDefinition
from osgeo import osr from osgeo import osr
import os import os
import math
# os.environ['PROJ_LIB'] = r"D:\Anaconda\envs\micro\Lib\site-packages\osgeo\data\proj" # os.environ['PROJ_LIB'] = r"D:\Anaconda\envs\micro\Lib\site-packages\osgeo\data\proj"
@ -26,6 +27,7 @@ def griddata_geo(points, data, lon_grid, lat_grid, method, i, end_i):
grid_data = grid_data[:, :, 0] grid_data = grid_data[:, :, 0]
return [i, end_i, grid_data] return [i, end_i, grid_data]
def griddataBlock(start_x, len_x, start_y, len_y, grid_data_input, grid_x, grid_y, method): def griddataBlock(start_x, len_x, start_y, len_y, grid_data_input, grid_x, grid_y, method):
grid_x = grid_x.reshape(-1) grid_x = grid_x.reshape(-1)
grid_y = grid_y.reshape(-1) grid_y = grid_y.reshape(-1)
@ -82,10 +84,12 @@ class polyfit2d_U:
class TransImgL1A: class TransImgL1A:
def __init__(self, ori_sim_path, roi): def __init__(self, ori_sim_path, roi, l1a_height, l1a_width):
self._begin_r, self._begin_c, self._end_r, self._end_c = 0, 0, 0, 0 self._begin_r, self._begin_c, self._end_r, self._end_c = 0, 0, 0, 0
self.ori2geo_img = None self.ori2geo_img = None
self._mask = None self._mask = None
self.l1a_height = l1a_height
self.l1a_width = l1a_width
self._min_lon, self._max_lon, self._min_lat, self._max_lat = 0, 0, 0, 0 self._min_lon, self._max_lon, self._min_lat, self._max_lat = 0, 0, 0, 0
self.init_trans_para(ori_sim_path, roi) self.init_trans_para(ori_sim_path, roi)
@ -94,7 +98,6 @@ class TransImgL1A:
data = [(self._begin_r + row, self._begin_c + col) for (row, col) in zip(rowcol[0], rowcol[1])] data = [(self._begin_r + row, self._begin_c + col) for (row, col) in zip(rowcol[0], rowcol[1])]
return data return data
def get_lonlat_points(self): def get_lonlat_points(self):
lon = self.ori2geo_img[0, :, :][np.where(self._mask == 1)] lon = self.ori2geo_img[0, :, :][np.where(self._mask == 1)]
lat = self.ori2geo_img[1, :, :][np.where(self._mask == 1)] lat = self.ori2geo_img[1, :, :][np.where(self._mask == 1)]
@ -104,39 +107,118 @@ class TransImgL1A:
###################### ######################
# 插值方法 # 插值方法
###################### ######################
def init_trans_para(self, ori_sim_path, roi): def init_trans_para(self, sim_ori_path, roi):
"""裁剪L1a_img """裁剪L1a_img --裁剪L1A影像
--- 修改 ori_sim 变换为 sim_ori
Args: Args:
src_img_path (_type_): 原始L1A影像 src_img_path (_type_): 原始L1A影像
cuted_img_path (_type_): 待裁剪对象 cuted_img_path (_type_): 待裁剪对象
roi (_type_): 裁剪roi roi (_type_): 裁剪roi
""" """
ori2geo_img = ImageHandle.ImageHandler.get_data(ori_sim_path) ori2geo_img_height = ImageHandle.ImageHandler.get_img_height(sim_ori_path)
ori2geo_img_width = ImageHandle.ImageHandler.get_img_width(sim_ori_path)
ori2geo_img = ImageHandle.ImageHandler.get_data(sim_ori_path)
ori2geo_gt = ImageHandle.ImageHandler.get_geotransform(sim_ori_path)
point_list = np.array(roi) point_list = np.array(roi)
min_lon = np.nanmin(point_list[:, 0]) min_lon = np.nanmin(point_list[:, 0])
max_lon = np.nanmax(point_list[:, 0]) max_lon = np.nanmax(point_list[:, 0])
min_lat = np.nanmin(point_list[:, 1]) min_lat = np.nanmin(point_list[:, 1])
max_lat = np.nanmax(point_list[:, 1]) max_lat = np.nanmax(point_list[:, 1])
self._min_lon, self._max_lon, self._min_lat, self._max_lat = min_lon, max_lon, min_lat, max_lat self._min_lon, self._max_lon, self._min_lat, self._max_lat = min_lon, max_lon, min_lat, max_lat
# 根据 min_lon max_lon
# 根据 min_lat max_lat
r_c_list = np.where( (x_min, y_min) = ImageHandle.ImageHandler.lat_lon_to_pixel(sim_ori_path, (min_lon, min_lat))
(ori2geo_img[0, :, :] >= min_lon) & (ori2geo_img[0, :, :] <= max_lon) (x_max, y_max) = ImageHandle.ImageHandler.lat_lon_to_pixel(sim_ori_path, (max_lon, max_lat))
& (ori2geo_img[1, :, :] >= min_lat) & (ori2geo_img[1, :, :] <= max_lat)) #
if len(r_c_list) == 0 or r_c_list[0] == [] or r_c_list[1] == [] or np.array(r_c_list).size == 0: xmin = x_min if x_min < x_max else x_max
xmax = x_min if x_min > x_max else x_max
ymin = y_min if y_min < y_max else y_max
ymax = y_min if y_min > y_max else y_max
xmin = int(math.floor(xmin)) # 列号
xmax = int(math.ceil(xmax)) # 因为python 的索引机制
# xmax = int(math.ceil(xmax)) + 1 # 因为python 的索引机制
ymin = int(math.floor(ymin)) # 行号
ymax = int(math.ceil(ymax)) # 因为pytohn的索引机制
# ymax = int(math.ceil(ymax)) + 1 # 因为pytohn的索引机制
# 处理最大最小范围
xmin = 0 if 0 > xmin else xmin
ymin = 0 if 0 > ymin else ymin
xmax = ori2geo_img_width if ori2geo_img_width > xmax else xmax
ymax = ori2geo_img_height if ori2geo_img_height > ymax else ymax
# 判断条件
xmax = xmax + 1 if xmax == xmin else xmax
ymax = ymax + 1 if ymax == ymin else ymax
if ymax <= ymin or xmax <= xmin or ymax > ori2geo_img_height or xmax > ori2geo_img_width or xmin < 0 or ymin < 0 or xmin > ori2geo_img_width or ymin > ori2geo_img_height or ymax < 0 or xmax < 0:
msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data' msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data'
print(msg) print(msg)
else: else:
r_arr = ori2geo_img[0, ymin:ymax, xmin:xmax]
c_arr = ori2geo_img[1, ymin:ymax, xmin:xmax]
# 构建坐标矩阵
ori2geo_mask_r_count = ymax - ymin
ori2geo_mask_c_count = xmax - xmin
lon_lat_arr = np.ones((2, ori2geo_mask_r_count, ori2geo_mask_c_count))
col_arr = np.arange(xmin, xmax) * np.ones((ori2geo_mask_r_count, ori2geo_mask_c_count))
row_arr = ((np.arange(ymin, ymax)) * np.ones((ori2geo_mask_c_count, ori2geo_mask_r_count))).T
img_geotrans = ori2geo_gt
lon_arr = img_geotrans[0] + img_geotrans[1] * col_arr + img_geotrans[2] * row_arr
lat_arr = img_geotrans[3] + img_geotrans[4] * col_arr + img_geotrans[5] * row_arr
lon_lat_arr[0, :, :] = lon_arr
lon_lat_arr[1, :, :] = lat_arr
# print("csv_roi:") # print("csv_roi:")
# print(roi) # print(roi)
r_min = np.nanmin(r_c_list[0]) r_min = np.floor(np.nanmin(r_arr)) # 获取 L1A 的行列号范围
r_max = np.nanmax(r_c_list[0]) r_max = np.ceil(np.nanmax(r_arr)) + 1
c_min = np.nanmin(r_c_list[1]) c_min = np.floor(np.nanmin(c_arr))
c_max = np.nanmax(r_c_list[1]) c_max = np.ceil(np.nanmax(c_arr)) + 1
self.ori2geo_img = ori2geo_img[:, r_min:r_max + 1, c_min:c_max + 1]
# 开始调用组件 计算 # 判断是否越界
r_min = 0 if r_min < 0 else r_min
r_max = self.l1a_height if r_max > self.l1a_height else r_max
c_min = 0 if c_min < 0 else c_min
c_max = self.l1a_width if c_max > self.l1a_width else c_max
# 判断条件
r_max = r_max + 1 if r_min == r_max else r_max
c_max = c_max + 1 if c_min == c_max else c_max
if r_max <= r_min or c_max <= c_min or r_max > self.l1a_height or c_max > self.l1a_width or r_min < 0 or c_min < 0 or c_min > self.l1a_width or r_min > self.l1a_height or r_max < 0 or c_max < 0:
msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data'
else:
pass
mask_geo = SAR_GEO.cut_L1A_img(lon_lat_arr, point_list) # 在地理坐标系下裁剪对应影像
mask_geo = mask_geo.reshape(-1)
r_arr = r_arr.reshape(-1)
c_arr = c_arr.reshape(-1)
mask_geo_idx = np.where(mask_geo == 1)[0]
if mask_geo_idx.shape[0] == 0:
msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data'
print(msg)
else:
r_idx = r_arr[mask_geo_idx]
c_idx = c_arr[mask_geo_idx]
r_idx = r_idx - r_min # offset row
c_idx = c_idx - c_min # offset col
r_count = r_max - r_min # 行数
c_count = c_max - c_min # 列数
#
mask_l1a = np.zeros((r_count, c_count)) * np.nan # 创建目标大小的行列号
mask = SAR_GEO.gereratorMask(r_idx.astype(np.float64), c_idx.astype(np.float64).astype(np.float64),
mask_l1a) # 这个函数修改了
mask = SAR_GEO.cut_L1A_img(self.ori2geo_img.astype(np.float64), point_list)
self._begin_r = r_min self._begin_r = r_min
self._end_r = r_max self._end_r = r_max
self._begin_c = c_min self._begin_c = c_min
@ -146,7 +228,7 @@ class TransImgL1A:
def cut_L1A(self, in_path, out_path): def cut_L1A(self, in_path, out_path):
img = ImageHandle.ImageHandler.get_data(in_path) img = ImageHandle.ImageHandler.get_data(in_path)
if len(img.shape) == 3: if len(img.shape) == 3:
cut_img = img[:, self._begin_r:self._end_r + 1, self._begin_c:self._end_c + 1] cut_img = img[:, self._begin_r:self._end_r, self._begin_c:self._end_c]
cut_img[0, :, :] = cut_img[0, :, :] * self._mask cut_img[0, :, :] = cut_img[0, :, :] * self._mask
cut_img[1, :, :] = cut_img[1, :, :] * self._mask cut_img[1, :, :] = cut_img[1, :, :] * self._mask
ImageHandle.ImageHandler.write_img(out_path, '', [0, 0, 0, 0, 0, 0], cut_img) ImageHandle.ImageHandler.write_img(out_path, '', [0, 0, 0, 0, 0, 0], cut_img)
@ -303,8 +385,7 @@ class TransImgL1A:
f_c = scipy.interpolate.interp2d(r_c_list[:, 2], r_c_list[:, 3], r_c_list[:, 1], kind='linear') f_c = scipy.interpolate.interp2d(r_c_list[:, 2], r_c_list[:, 3], r_c_list[:, 1], kind='linear')
tar_get_r = f_r(p[0], p[1])[0] tar_get_r = f_r(p[0], p[1])[0]
tar_get_c = f_c(p[0], p[1])[0] tar_get_c = f_c(p[0], p[1])[0]
if tar_get_r < ori2geo_tif.shape[1] and tar_get_c < ori2geo_tif.shape[ if tar_get_r < ori2geo_tif.shape[1] and tar_get_c < ori2geo_tif.shape[2] and tar_get_r>=0 and tar_get_c>=0:
2] and tar_get_r >= 0 and tar_get_c >= 0:
lon_temp = ori2geo_tif[0, int(round(tar_get_r)), int(round(tar_get_c))] lon_temp = ori2geo_tif[0, int(round(tar_get_r)), int(round(tar_get_c))]
lon_lat = ori2geo_tif[1, int(round(tar_get_r)), int(round(tar_get_c))] lon_lat = ori2geo_tif[1, int(round(tar_get_r)), int(round(tar_get_c))]
# 增加条件筛选 # 增加条件筛选
@ -313,9 +394,12 @@ class TransImgL1A:
result.append([-1, -1]) result.append([-1, -1])
return result return result
def tran_lonlats_to_L1A_rowcols(self, meas_data, ori_sim_path): def tran_lonlats_to_L1A_rowcols(self, meas_data, ori_sim_path, row, col):
lonlats = [] lonlats = []
data_roi = [] data_roi = []
rowcols = []
measdata_list = []
data_sim = ImageHandle.ImageHandler.get_all_band_array(ori_sim_path)
for data in meas_data: for data in meas_data:
lon = float(data[1]) lon = float(data[1])
lat = float(data[2]) lat = float(data[2])
@ -323,12 +407,22 @@ class TransImgL1A:
lonlats.append([lon, lat]) lonlats.append([lon, lat])
data_roi.append(data) data_roi.append(data)
rowcols = self.tran_lonlats_to_rowcols(lonlats, ori_sim_path) for lonlat in lonlats:
measdata_list = [] (x, y) = ImageHandle.ImageHandler.lat_lon_to_pixel(ori_sim_path, lonlat)
rowcols.append([x, y])
for data, rowcol in zip(data_roi, rowcols): for data, rowcol in zip(data_roi, rowcols):
if (rowcol[0] != -1 and rowcol[1] != -1): img_x = round(data_sim[round(rowcol[1]), round(rowcol[0]), 0])
measdata_list.append( img_y = round(data_sim[round(rowcol[1]), round(rowcol[0]), 1])
[round(rowcol[0]) - self._begin_r, round(rowcol[1]) - self._begin_c, float(data[3])]) if (img_x > 0 and img_x < row and img_y > 0 and img_y < col):
measdata_list.append([img_x, img_y, float(data[3])])
# rowcols = self.tran_lonlats_to_rowcols(lonlats, ori_sim_path)
# measdata_list = []
# for data, rowcol in zip(data_roi, rowcols):
# if (rowcol[0] != -1 and rowcol[1] != -1):
# measdata_list.append(
# [round(rowcol[0]) - self._begin_r, round(rowcol[1]) - self._begin_c, float(data[3])])
return measdata_list return measdata_list
@staticmethod @staticmethod
@ -536,8 +630,8 @@ if __name__ == '__main__':
""" """
# roi_Extend = [[102.12, 33.879], [102.327, 33.879], [102.327, 33.66], [102.12, 31.45]] # roi_Extend = [[102.12, 33.879], [102.327, 33.879], [102.327, 33.66], [102.12, 31.45]]
ori_sim_data = ImageHandle.ImageHandler.get_data(ori_sim) ori_sim_data = ImageHandle.ImageHandler.get_data(ori_sim)
lon = ori_sim_data[0,:,:] lon = ori_sim_data[0, :, :]
lat = ori_sim_data[1,:,:] lat = ori_sim_data[1, :, :]
min_lon = np.nanmin(lon) min_lon = np.nanmin(lon)
max_lon = np.nanmax(lon) max_lon = np.nanmax(lon)
min_lat = np.nanmin(lat) min_lat = np.nanmin(lat)