diff --git a/tool/algorithm/transforml1a/SAR_GEO.cp38-win_amd64.pyd b/tool/algorithm/transforml1a/SAR_GEO.cp38-win_amd64.pyd index 1482c62c..642f519a 100644 Binary files a/tool/algorithm/transforml1a/SAR_GEO.cp38-win_amd64.pyd and b/tool/algorithm/transforml1a/SAR_GEO.cp38-win_amd64.pyd differ diff --git a/tool/algorithm/transforml1a/SAR_geo/SAR_GEO.pyx b/tool/algorithm/transforml1a/SAR_geo/SAR_GEO.pyx index 619d83c9..c443ac33 100644 --- a/tool/algorithm/transforml1a/SAR_geo/SAR_GEO.pyx +++ b/tool/algorithm/transforml1a/SAR_geo/SAR_GEO.pyx @@ -2,7 +2,7 @@ import os cimport cython # 必须导入 import numpy as np##必须为c类型和python类型的数据都申明一个np cimport numpy as np # 必须为c类型和python类型的数据都申明一个np -from libc.math cimport pi +from libc.math cimport pi,ceil,floor from scipy.interpolate import griddata @@ -86,17 +86,54 @@ cpdef np.ndarray[double,ndim=2] cut_L1A_img(np.ndarray[double,ndim=3] ori2geo_i while i=0 and temp_col>=0 and temp_row=0 and temp_col>=0 and temp_row=0 and temp_col>=0 and temp_row=0 and temp_col>=0 and temp_row= min_lon) & (ori2geo_img[0, :, :] <= max_lon) - & (ori2geo_img[1, :, :] >= min_lat) & (ori2geo_img[1, :, :] <= max_lat)) # + (x_min, y_min) = ImageHandle.ImageHandler.lat_lon_to_pixel(sim_ori_path, (min_lon, min_lat)) + (x_max, y_max) = ImageHandle.ImageHandler.lat_lon_to_pixel(sim_ori_path, (max_lon, max_lat)) - if len(r_c_list) == 0 or r_c_list[0] == [] or r_c_list[1] == [] or np.array(r_c_list).size == 0: + xmin = x_min if x_min < x_max else x_max + xmax = x_min if x_min > x_max else x_max + + ymin = y_min if y_min < y_max else y_max + ymax = y_min if y_min > y_max else y_max + + xmin = int(math.floor(xmin)) # 列号 + xmax = int(math.ceil(xmax)) # 因为python 的索引机制 + # xmax = int(math.ceil(xmax)) + 1 # 因为python 的索引机制 + ymin = int(math.floor(ymin)) # 行号 + ymax = int(math.ceil(ymax)) # 因为pytohn的索引机制 + # ymax = int(math.ceil(ymax)) + 1 # 因为pytohn的索引机制 + + # 处理最大最小范围 + xmin = 0 if 0 > xmin else xmin + ymin = 0 if 0 > ymin else ymin + xmax = ori2geo_img_width if ori2geo_img_width > xmax else xmax + ymax = ori2geo_img_height if ori2geo_img_height > ymax else ymax + + # 判断条件 + xmax = xmax + 1 if xmax == xmin else xmax + ymax = ymax + 1 if ymax == ymin else ymax + + if ymax <= ymin or xmax <= xmin or ymax > ori2geo_img_height or xmax > ori2geo_img_width or xmin < 0 or ymin < 0 or xmin > ori2geo_img_width or ymin > ori2geo_img_height or ymax < 0 or xmax < 0: msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data!' print(msg) else: + r_arr = ori2geo_img[0, ymin:ymax, xmin:xmax] + c_arr = ori2geo_img[1, ymin:ymax, xmin:xmax] + + # 构建坐标矩阵 + ori2geo_mask_r_count = ymax - ymin + ori2geo_mask_c_count = xmax - xmin + + lon_lat_arr = np.ones((2, ori2geo_mask_r_count, ori2geo_mask_c_count)) + col_arr = np.arange(xmin, xmax) * np.ones((ori2geo_mask_r_count, ori2geo_mask_c_count)) + row_arr = ((np.arange(ymin, ymax)) * np.ones((ori2geo_mask_c_count, ori2geo_mask_r_count))).T + + img_geotrans = ori2geo_gt + lon_arr = img_geotrans[0] + img_geotrans[1] * col_arr + img_geotrans[2] * row_arr + lat_arr = img_geotrans[3] + img_geotrans[4] * col_arr + img_geotrans[5] * row_arr + lon_lat_arr[0, :, :] = lon_arr + lon_lat_arr[1, :, :] = lat_arr + # print("csv_roi:") # print(roi) - r_min = np.nanmin(r_c_list[0]) - r_max = np.nanmax(r_c_list[0]) - c_min = np.nanmin(r_c_list[1]) - c_max = np.nanmax(r_c_list[1]) - self.ori2geo_img = ori2geo_img[:, r_min:r_max + 1, c_min:c_max + 1] - # 开始调用组件 计算 + r_min = np.floor(np.nanmin(r_arr)) # 获取 L1A 的行列号范围 + r_max = np.ceil(np.nanmax(r_arr)) + 1 + c_min = np.floor(np.nanmin(c_arr)) + c_max = np.ceil(np.nanmax(c_arr)) + 1 - mask = SAR_GEO.cut_L1A_img(self.ori2geo_img.astype(np.float64), point_list) - self._begin_r = r_min - self._end_r = r_max - self._begin_c = c_min - self._end_c = c_max - self._mask = mask + # 判断是否越界 + r_min = 0 if r_min < 0 else r_min + r_max = self.l1a_height if r_max > self.l1a_height else r_max + c_min = 0 if c_min < 0 else c_min + c_max = self.l1a_width if c_max > self.l1a_width else c_max + + # 判断条件 + r_max = r_max + 1 if r_min == r_max else r_max + c_max = c_max + 1 if c_min == c_max else c_max + if r_max <= r_min or c_max <= c_min or r_max > self.l1a_height or c_max > self.l1a_width or r_min < 0 or c_min < 0 or c_min > self.l1a_width or r_min > self.l1a_height or r_max < 0 or c_max < 0: + msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data!' + else: + pass + mask_geo = SAR_GEO.cut_L1A_img(lon_lat_arr, point_list) # 在地理坐标系下裁剪对应影像 + + mask_geo = mask_geo.reshape(-1) + r_arr = r_arr.reshape(-1) + c_arr = c_arr.reshape(-1) + + mask_geo_idx = np.where(mask_geo == 1)[0] + if mask_geo_idx.shape[0] == 0: + msg = 'csv_roi:' + str(roi) + 'not in box,please revise csv data!' + print(msg) + else: + r_idx = r_arr[mask_geo_idx] + c_idx = c_arr[mask_geo_idx] + + r_idx = r_idx - r_min # offset row + c_idx = c_idx - c_min # offset col + r_count = r_max - r_min # 行数 + c_count = c_max - c_min # 列数 + + # + mask_l1a = np.zeros((r_count, c_count)) * np.nan # 创建目标大小的行列号 + mask = SAR_GEO.gereratorMask(r_idx.astype(np.float64), c_idx.astype(np.float64).astype(np.float64), + mask_l1a) # 这个函数修改了 + + self._begin_r = r_min + self._end_r = r_max + self._begin_c = c_min + self._end_c = c_max + self._mask = mask def cut_L1A(self, in_path, out_path): img = ImageHandle.ImageHandler.get_data(in_path) if len(img.shape) == 3: - cut_img = img[:, self._begin_r:self._end_r + 1, self._begin_c:self._end_c + 1] + cut_img = img[:, self._begin_r:self._end_r, self._begin_c:self._end_c] cut_img[0, :, :] = cut_img[0, :, :] * self._mask cut_img[1, :, :] = cut_img[1, :, :] * self._mask ImageHandle.ImageHandler.write_img(out_path, '', [0, 0, 0, 0, 0, 0], cut_img) @@ -233,7 +315,7 @@ class TransImgL1A: if is_class: ori2geo_tif = np.round(ori2geo_tif).astype(np.int32) mask = (ori2geo_tif[0, :, :] >= 0) & (ori2geo_tif[0, :, :] < width) & (ori2geo_tif[1, :, :] >= 0) & ( - ori2geo_tif[1, :, :] < height) + ori2geo_tif[1, :, :] < height) ori2geo_tif[0, :, :] = ori2geo_tif[0, :, :] * mask ori2geo_tif[1, :, :] = ori2geo_tif[1, :, :] * mask geo_tif_shape = geo_tif.shape @@ -248,7 +330,7 @@ class TransImgL1A: return geo_tif_l1a else: # 数值性插值 mask = (ori2geo_tif[0, :, :] > 0) & (ori2geo_tif[0, :, :] < width - 1) & (ori2geo_tif[1, :, :] > 0) & ( - ori2geo_tif[1, :, :] < height - 1) + ori2geo_tif[1, :, :] < height - 1) one_ids = np.where(mask == 1) x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) result_data = self.grid_interp_to_station([y.reshape(-1), x.reshape(-1), geo_tif.reshape(-1)], @@ -303,8 +385,7 @@ class TransImgL1A: f_c = scipy.interpolate.interp2d(r_c_list[:, 2], r_c_list[:, 3], r_c_list[:, 1], kind='linear') tar_get_r = f_r(p[0], p[1])[0] tar_get_c = f_c(p[0], p[1])[0] - if tar_get_r < ori2geo_tif.shape[1] and tar_get_c < ori2geo_tif.shape[ - 2] and tar_get_r >= 0 and tar_get_c >= 0: + if tar_get_r < ori2geo_tif.shape[1] and tar_get_c < ori2geo_tif.shape[2] and tar_get_r>=0 and tar_get_c>=0: lon_temp = ori2geo_tif[0, int(round(tar_get_r)), int(round(tar_get_c))] lon_lat = ori2geo_tif[1, int(round(tar_get_r)), int(round(tar_get_c))] # 增加条件筛选 @@ -313,9 +394,12 @@ class TransImgL1A: result.append([-1, -1]) return result - def tran_lonlats_to_L1A_rowcols(self, meas_data, ori_sim_path): + def tran_lonlats_to_L1A_rowcols(self, meas_data, ori_sim_path, row, col): lonlats = [] data_roi = [] + rowcols = [] + measdata_list = [] + data_sim = ImageHandle.ImageHandler.get_all_band_array(ori_sim_path) for data in meas_data: lon = float(data[1]) lat = float(data[2]) @@ -323,12 +407,22 @@ class TransImgL1A: lonlats.append([lon, lat]) data_roi.append(data) - rowcols = self.tran_lonlats_to_rowcols(lonlats, ori_sim_path) - measdata_list = [] + for lonlat in lonlats: + (x, y) = ImageHandle.ImageHandler.lat_lon_to_pixel(ori_sim_path, lonlat) + rowcols.append([x, y]) + for data, rowcol in zip(data_roi, rowcols): - if (rowcol[0] != -1 and rowcol[1] != -1): - measdata_list.append( - [round(rowcol[0]) - self._begin_r, round(rowcol[1]) - self._begin_c, float(data[3])]) + img_x = round(data_sim[round(rowcol[1]), round(rowcol[0]), 0]) + img_y = round(data_sim[round(rowcol[1]), round(rowcol[0]), 1]) + if (img_x > 0 and img_x < row and img_y > 0 and img_y < col): + measdata_list.append([img_x, img_y, float(data[3])]) + + # rowcols = self.tran_lonlats_to_rowcols(lonlats, ori_sim_path) + # measdata_list = [] + # for data, rowcol in zip(data_roi, rowcols): + # if (rowcol[0] != -1 and rowcol[1] != -1): + # measdata_list.append( + # [round(rowcol[0]) - self._begin_r, round(rowcol[1]) - self._begin_c, float(data[3])]) return measdata_list @staticmethod @@ -536,8 +630,8 @@ if __name__ == '__main__': """ # roi_Extend = [[102.12, 33.879], [102.327, 33.879], [102.327, 33.66], [102.12, 31.45]] ori_sim_data = ImageHandle.ImageHandler.get_data(ori_sim) - lon = ori_sim_data[0,:,:] - lat = ori_sim_data[1,:,:] + lon = ori_sim_data[0, :, :] + lat = ori_sim_data[1, :, :] min_lon = np.nanmin(lon) max_lon = np.nanmax(lon) min_lat = np.nanmin(lat)