更新裁剪方法,通过地距映射表进行斜距影像裁剪

dev
tian jiax 2023-11-22 09:28:51 +08:00
parent 762078fc90
commit c6cbe20b87
3 changed files with 170 additions and 39 deletions

View File

@ -2,7 +2,7 @@ import os
cimport cython # 必须导入
import numpy as np##必须为c类型和python类型的数据都申明一个np
cimport numpy as np # 必须为c类型和python类型的数据都申明一个np
from libc.math cimport pi
from libc.math cimport pi,ceil,floor
from scipy.interpolate import griddata
@ -86,17 +86,54 @@ cpdef np.ndarray[double,ndim=2] cut_L1A_img(np.ndarray[double,ndim=3] ori2geo_i
while i<height:
j=0
while j<width:
temp_p.x=ori2geo_img[0,i,j]
temp_p.y=ori2geo_img[1,i,j]
temp_p.x=ori2geo_img[0,i,j] # temp_p
temp_p.y=ori2geo_img[1,i,j] # temp_p
if rayCasting(temp_p,roi_list)==1:
mask[i,j]=1
else:
mask[i,j]=np.nan
j=j+1
i=i+1
return mask
cpdef np.ndarray[double,ndim=2] gereratorMask(np.ndarray[double,ndim=1] rlist,np.ndarray[double,ndim=1] clist,np.ndarray[double,ndim=2] mask):
cdef int rcount=rlist.shape[0]
cdef int ccount=clist.shape[0]
cdef int count=rcount if rcount<ccount else ccount
cdef int i=0
cdef int j=0
cdef int temp_row=0
cdef int temp_col=0
cdef int height=mask.shape[0]
cdef int width=mask.shape[1]
while i<count:
# 1
temp_row=int(ceil(rlist[i]))
temp_col=int(ceil(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
# 2
temp_row=int(floor(rlist[i]))
temp_col=int(ceil(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
# 3
temp_row=int(ceil(rlist[i]))
temp_col=int(floor(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
# 4
temp_row=int(floor(rlist[i]))
temp_col=int(floor(clist[i]))
if temp_row>=0 and temp_col>=0 and temp_row<height and temp_col<width:
mask[temp_row,temp_col]=1
i=i+1
return mask
cdef double distance_powe(Point p1,Point p2):

View File

@ -4,13 +4,14 @@ import numpy as np
import scipy
from scipy.interpolate import griddata, RegularGridInterpolator
import logging
import pyresample as pr
# import pyresample as pr
# 插值模块
from pyresample.bilinear import NumpyBilinearResampler
from pyresample import geometry
from pyresample.geometry import AreaDefinition
from osgeo import osr
import os
import math
# os.environ['PROJ_LIB'] = r"D:\Anaconda\envs\micro\Lib\site-packages\osgeo\data\proj"
@ -26,6 +27,7 @@ def griddata_geo(points, data, lon_grid, lat_grid, method, i, end_i):
grid_data = grid_data[:, :, 0]
return [i, end_i, grid_data]
def griddataBlock(start_x, len_x, start_y, len_y, grid_data_input, grid_x, grid_y, method):
grid_x = grid_x.reshape(-1)
grid_y = grid_y.reshape(-1)
@ -82,10 +84,12 @@ class polyfit2d_U:
class TransImgL1A:
def __init__(self, ori_sim_path, roi):
def __init__(self, ori_sim_path, roi, l1a_height, l1a_width):
self._begin_r, self._begin_c, self._end_r, self._end_c = 0, 0, 0, 0
self.ori2geo_img = None
self._mask = None
self.l1a_height = l1a_height
self.l1a_width = l1a_width
self._min_lon, self._max_lon, self._min_lat, self._max_lat = 0, 0, 0, 0
self.init_trans_para(ori_sim_path, roi)
@ -94,7 +98,6 @@ class TransImgL1A:
data = [(self._begin_r + row, self._begin_c + col) for (row, col) in zip(rowcol[0], rowcol[1])]
return data
def get_lonlat_points(self):
lon = self.ori2geo_img[0, :, :][np.where(self._mask == 1)]
lat = self.ori2geo_img[1, :, :][np.where(self._mask == 1)]
@ -104,49 +107,128 @@ class TransImgL1A:
######################
# 插值方法
######################
def init_trans_para(self, ori_sim_path, roi):
"""裁剪L1a_img
def init_trans_para(self, sim_ori_path, roi):
"""裁剪L1a_img --裁剪L1A影像
--- 修改 ori_sim 变换为 sim_ori
Args:
src_img_path (_type_): 原始L1A影像
cuted_img_path (_type_): 待裁剪对象
roi (_type_): 裁剪roi
"""
ori2geo_img = ImageHandle.ImageHandler.get_data(ori_sim_path)
ori2geo_img_height = ImageHandle.ImageHandler.get_img_height(sim_ori_path)
ori2geo_img_width = ImageHandle.ImageHandler.get_img_width(sim_ori_path)
ori2geo_img = ImageHandle.ImageHandler.get_data(sim_ori_path)
ori2geo_gt = ImageHandle.ImageHandler.get_geotransform(sim_ori_path)
point_list = np.array(roi)
min_lon = np.nanmin(point_list[:, 0])
max_lon = np.nanmax(point_list[:, 0])
min_lat = np.nanmin(point_list[:, 1])
max_lat = np.nanmax(point_list[:, 1])
self._min_lon, self._max_lon, self._min_lat, self._max_lat = min_lon, max_lon, min_lat, max_lat
# 根据 min_lon max_lon
# 根据 min_lat max_lat
r_c_list = np.where(
(ori2geo_img[0, :, :] >= min_lon) & (ori2geo_img[0, :, :] <= max_lon)
& (ori2geo_img[1, :, :] >= min_lat) & (ori2geo_img[1, :, :] <= max_lat)) #
(x_min, y_min) = ImageHandle.ImageHandler.lat_lon_to_pixel(sim_ori_path, (min_lon, min_lat))
(x_max, y_max) = ImageHandle.ImageHandler.lat_lon_to_pixel(sim_ori_path, (max_lon, max_lat))
if len(r_c_list) == 0 or r_c_list[0] == [] or r_c_list[1] == [] or np.array(r_c_list).size == 0:
xmin = x_min if x_min < x_max else x_max
xmax = x_min if x_min > x_max else x_max
ymin = y_min if y_min < y_max else y_max
ymax = y_min if y_min > y_max else y_max
xmin = int(math.floor(xmin)) # 列号
xmax = int(math.ceil(xmax)) # 因为python 的索引机制
# xmax = int(math.ceil(xmax)) + 1 # 因为python 的索引机制
ymin = int(math.floor(ymin)) # 行号
ymax = int(math.ceil(ymax)) # 因为pytohn的索引机制
# ymax = int(math.ceil(ymax)) + 1 # 因为pytohn的索引机制
# 处理最大最小范围
xmin = 0 if 0 > xmin else xmin
ymin = 0 if 0 > ymin else ymin
xmax = ori2geo_img_width if ori2geo_img_width > xmax else xmax
ymax = ori2geo_img_height if ori2geo_img_height > ymax else ymax
# 判断条件
xmax = xmax + 1 if xmax == xmin else xmax
ymax = ymax + 1 if ymax == ymin else ymax
if ymax <= ymin or xmax <= xmin or ymax > ori2geo_img_height or xmax > ori2geo_img_width or xmin < 0 or ymin < 0 or xmin > ori2geo_img_width or ymin > ori2geo_img_height or ymax < 0 or xmax < 0:
msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data'
print(msg)
else:
r_arr = ori2geo_img[0, ymin:ymax, xmin:xmax]
c_arr = ori2geo_img[1, ymin:ymax, xmin:xmax]
# 构建坐标矩阵
ori2geo_mask_r_count = ymax - ymin
ori2geo_mask_c_count = xmax - xmin
lon_lat_arr = np.ones((2, ori2geo_mask_r_count, ori2geo_mask_c_count))
col_arr = np.arange(xmin, xmax) * np.ones((ori2geo_mask_r_count, ori2geo_mask_c_count))
row_arr = ((np.arange(ymin, ymax)) * np.ones((ori2geo_mask_c_count, ori2geo_mask_r_count))).T
img_geotrans = ori2geo_gt
lon_arr = img_geotrans[0] + img_geotrans[1] * col_arr + img_geotrans[2] * row_arr
lat_arr = img_geotrans[3] + img_geotrans[4] * col_arr + img_geotrans[5] * row_arr
lon_lat_arr[0, :, :] = lon_arr
lon_lat_arr[1, :, :] = lat_arr
# print("csv_roi:")
# print(roi)
r_min = np.nanmin(r_c_list[0])
r_max = np.nanmax(r_c_list[0])
c_min = np.nanmin(r_c_list[1])
c_max = np.nanmax(r_c_list[1])
self.ori2geo_img = ori2geo_img[:, r_min:r_max + 1, c_min:c_max + 1]
# 开始调用组件 计算
r_min = np.floor(np.nanmin(r_arr)) # 获取 L1A 的行列号范围
r_max = np.ceil(np.nanmax(r_arr)) + 1
c_min = np.floor(np.nanmin(c_arr))
c_max = np.ceil(np.nanmax(c_arr)) + 1
mask = SAR_GEO.cut_L1A_img(self.ori2geo_img.astype(np.float64), point_list)
self._begin_r = r_min
self._end_r = r_max
self._begin_c = c_min
self._end_c = c_max
self._mask = mask
# 判断是否越界
r_min = 0 if r_min < 0 else r_min
r_max = self.l1a_height if r_max > self.l1a_height else r_max
c_min = 0 if c_min < 0 else c_min
c_max = self.l1a_width if c_max > self.l1a_width else c_max
# 判断条件
r_max = r_max + 1 if r_min == r_max else r_max
c_max = c_max + 1 if c_min == c_max else c_max
if r_max <= r_min or c_max <= c_min or r_max > self.l1a_height or c_max > self.l1a_width or r_min < 0 or c_min < 0 or c_min > self.l1a_width or r_min > self.l1a_height or r_max < 0 or c_max < 0:
msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data'
else:
pass
mask_geo = SAR_GEO.cut_L1A_img(lon_lat_arr, point_list) # 在地理坐标系下裁剪对应影像
mask_geo = mask_geo.reshape(-1)
r_arr = r_arr.reshape(-1)
c_arr = c_arr.reshape(-1)
mask_geo_idx = np.where(mask_geo == 1)[0]
if mask_geo_idx.shape[0] == 0:
msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data'
print(msg)
else:
r_idx = r_arr[mask_geo_idx]
c_idx = c_arr[mask_geo_idx]
r_idx = r_idx - r_min # offset row
c_idx = c_idx - c_min # offset col
r_count = r_max - r_min # 行数
c_count = c_max - c_min # 列数
#
mask_l1a = np.zeros((r_count, c_count)) * np.nan # 创建目标大小的行列号
mask = SAR_GEO.gereratorMask(r_idx.astype(np.float64), c_idx.astype(np.float64).astype(np.float64),
mask_l1a) # 这个函数修改了
self._begin_r = r_min
self._end_r = r_max
self._begin_c = c_min
self._end_c = c_max
self._mask = mask
def cut_L1A(self, in_path, out_path):
img = ImageHandle.ImageHandler.get_data(in_path)
if len(img.shape) == 3:
cut_img = img[:, self._begin_r:self._end_r + 1, self._begin_c:self._end_c + 1]
cut_img = img[:, self._begin_r:self._end_r, self._begin_c:self._end_c]
cut_img[0, :, :] = cut_img[0, :, :] * self._mask
cut_img[1, :, :] = cut_img[1, :, :] * self._mask
ImageHandle.ImageHandler.write_img(out_path, '', [0, 0, 0, 0, 0, 0], cut_img)
@ -233,7 +315,7 @@ class TransImgL1A:
if is_class:
ori2geo_tif = np.round(ori2geo_tif).astype(np.int32)
mask = (ori2geo_tif[0, :, :] >= 0) & (ori2geo_tif[0, :, :] < width) & (ori2geo_tif[1, :, :] >= 0) & (
ori2geo_tif[1, :, :] < height)
ori2geo_tif[1, :, :] < height)
ori2geo_tif[0, :, :] = ori2geo_tif[0, :, :] * mask
ori2geo_tif[1, :, :] = ori2geo_tif[1, :, :] * mask
geo_tif_shape = geo_tif.shape
@ -248,7 +330,7 @@ class TransImgL1A:
return geo_tif_l1a
else: # 数值性插值
mask = (ori2geo_tif[0, :, :] > 0) & (ori2geo_tif[0, :, :] < width - 1) & (ori2geo_tif[1, :, :] > 0) & (
ori2geo_tif[1, :, :] < height - 1)
ori2geo_tif[1, :, :] < height - 1)
one_ids = np.where(mask == 1)
x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))
result_data = self.grid_interp_to_station([y.reshape(-1), x.reshape(-1), geo_tif.reshape(-1)],
@ -303,8 +385,7 @@ class TransImgL1A:
f_c = scipy.interpolate.interp2d(r_c_list[:, 2], r_c_list[:, 3], r_c_list[:, 1], kind='linear')
tar_get_r = f_r(p[0], p[1])[0]
tar_get_c = f_c(p[0], p[1])[0]
if tar_get_r < ori2geo_tif.shape[1] and tar_get_c < ori2geo_tif.shape[
2] and tar_get_r >= 0 and tar_get_c >= 0:
if tar_get_r < ori2geo_tif.shape[1] and tar_get_c < ori2geo_tif.shape[2] and tar_get_r>=0 and tar_get_c>=0:
lon_temp = ori2geo_tif[0, int(round(tar_get_r)), int(round(tar_get_c))]
lon_lat = ori2geo_tif[1, int(round(tar_get_r)), int(round(tar_get_c))]
# 增加条件筛选
@ -313,9 +394,12 @@ class TransImgL1A:
result.append([-1, -1])
return result
def tran_lonlats_to_L1A_rowcols(self, meas_data, ori_sim_path):
def tran_lonlats_to_L1A_rowcols(self, meas_data, ori_sim_path, row, col):
lonlats = []
data_roi = []
rowcols = []
measdata_list = []
data_sim = ImageHandle.ImageHandler.get_all_band_array(ori_sim_path)
for data in meas_data:
lon = float(data[1])
lat = float(data[2])
@ -323,12 +407,22 @@ class TransImgL1A:
lonlats.append([lon, lat])
data_roi.append(data)
rowcols = self.tran_lonlats_to_rowcols(lonlats, ori_sim_path)
measdata_list = []
for lonlat in lonlats:
(x, y) = ImageHandle.ImageHandler.lat_lon_to_pixel(ori_sim_path, lonlat)
rowcols.append([x, y])
for data, rowcol in zip(data_roi, rowcols):
if (rowcol[0] != -1 and rowcol[1] != -1):
measdata_list.append(
[round(rowcol[0]) - self._begin_r, round(rowcol[1]) - self._begin_c, float(data[3])])
img_x = round(data_sim[round(rowcol[1]), round(rowcol[0]), 0])
img_y = round(data_sim[round(rowcol[1]), round(rowcol[0]), 1])
if (img_x > 0 and img_x < row and img_y > 0 and img_y < col):
measdata_list.append([img_x, img_y, float(data[3])])
# rowcols = self.tran_lonlats_to_rowcols(lonlats, ori_sim_path)
# measdata_list = []
# for data, rowcol in zip(data_roi, rowcols):
# if (rowcol[0] != -1 and rowcol[1] != -1):
# measdata_list.append(
# [round(rowcol[0]) - self._begin_r, round(rowcol[1]) - self._begin_c, float(data[3])])
return measdata_list
@staticmethod
@ -536,8 +630,8 @@ if __name__ == '__main__':
"""
# roi_Extend = [[102.12, 33.879], [102.327, 33.879], [102.327, 33.66], [102.12, 31.45]]
ori_sim_data = ImageHandle.ImageHandler.get_data(ori_sim)
lon = ori_sim_data[0,:,:]
lat = ori_sim_data[1,:,:]
lon = ori_sim_data[0, :, :]
lat = ori_sim_data[1, :, :]
min_lon = np.nanmin(lon)
max_lon = np.nanmax(lon)
min_lat = np.nanmin(lat)