diff --git a/.gitignore b/.gitignore index 0d7da3c..e4bb0ee 100644 --- a/.gitignore +++ b/.gitignore @@ -713,3 +713,7 @@ FodyWeavers.xsd hs_err_pid* replay_pid* +# 其他数据处理程序 +ProgramEXE/ + + diff --git a/tools/ImageDataOperator/ImageHandle.py b/tools/ImageDataOperator/ImageHandle.py new file mode 100644 index 0000000..4e1cb38 --- /dev/null +++ b/tools/ImageDataOperator/ImageHandle.py @@ -0,0 +1,939 @@ +""" +@Project :microproduct +@File :ImageHandle.py +@Function :实现对待处理SAR数据的读取、格式标准化和处理完后保存文件功能 +@Author :LMM +@Date :2021/10/19 14:39 +@Version :1.0.0 +""" +import os +from PIL import Image +import time +from osgeo import gdal +from osgeo import osr +import numpy as np +from PIL import Image +import cv2 +from xml.etree.ElementTree import ElementTree, Element +import math + + + +class ImageHandler: + """ + 影像读取、编辑、保存 + """ + def __init__(self): + pass + @staticmethod + def get_dataset(filename): + """ + :param filename: tif路径 + :return: 图像句柄 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + return dataset + + def get_scope(self, filename): + """ + :param filename: tif路径 + :return: 图像范围 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + im_scope = self.cal_img_scope(dataset) + del dataset + return im_scope + + @staticmethod + def get_projection(filename): + """ + :param filename: tif路径 + :return: 地图投影信息 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + im_proj = dataset.GetProjection() + del dataset + return im_proj + + @staticmethod + def get_geotransform(filename): + """ + :param filename: tif路径 + :return: 从图像坐标空间(行、列),也称为(像素、线)到地理参考坐标空间(投影或地理坐标)的仿射变换 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + geotransform = dataset.GetGeoTransform() + del dataset + return geotransform + + def get_invgeotransform(filename): + """ + :param filename: tif路径 + :return: 从地理参考坐标空间(投影或地理坐标)的到图像坐标空间(行、列 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + geotransform = dataset.GetGeoTransform() + geotransform=gdal.InvGeoTransform(geotransform) + del dataset + return geotransform + + @staticmethod + def get_bands(filename): + """ + :param filename: tif路径 + :return: 影像的波段数 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + bands = dataset.RasterCount + del dataset + return bands + + @staticmethod + def geo2lonlat(dataset, x, y): + """ + 将投影坐标转为经纬度坐标(具体的投影坐标系由给定数据确定) + :param dataset: GDAL地理数据 + :param x: 投影坐标x + :param y: 投影坐标y + :return: 投影坐标(x, y)对应的经纬度坐标(lon, lat) + """ + prosrs = osr.SpatialReference() + prosrs.ImportFromWkt(dataset.GetProjection()) + geosrs = prosrs.CloneGeogCS() + ct = osr.CoordinateTransformation(prosrs, geosrs) + coords = ct.TransformPoint(x, y) + return coords[:2] + + @staticmethod + def get_band_array(filename, num=1): + """ + :param filename: tif路径 + :param num: 波段序号 + :return: 对应波段的矩阵数据 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + bands = dataset.GetRasterBand(num) + array = bands.ReadAsArray(0, 0, bands.XSize, bands.YSize) + + # if 'int' in str(array.dtype): + # array[np.where(array == -9999)] = np.inf + # else: + # array[np.where(array < -9000.0)] = np.nan + + del dataset + return array + + @staticmethod + def write_imgArray(filename, im_data, no_data='null'): + """ + 影像保存 + :param filename: 保存的路径 + :param im_proj: + :param im_geotrans: + :param im_data: + :param no_data: 把无效值设置为 nodata + :return: + """ + + gdal_dtypes = { + 'int8': gdal.GDT_Byte, + 'unit16': gdal.GDT_UInt16, + 'int16': gdal.GDT_Int16, + 'unit32': gdal.GDT_UInt32, + 'int32': gdal.GDT_Int32, + 'float32': gdal.GDT_Float32, + 'float64': gdal.GDT_Float64, + } + if not gdal_dtypes.get(im_data.dtype.name, None) is None: + datatype = gdal_dtypes[im_data.dtype.name] + else: + datatype = gdal.GDT_Float32 + + # 判读数组维数 + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + else: + im_bands, (im_height, im_width) = 1, im_data.shape + + # 创建文件 + if os.path.exists(os.path.split(filename)[0]) is False: + os.makedirs(os.path.split(filename)[0]) + + driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间 + dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) + + if im_bands == 1: + # outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据 + outband = dataset.GetRasterBand(1) + outband.WriteArray(im_data) + if no_data != 'null': + outband.SetNoDataValue(no_data) + outband.FlushCache() + else: + for i in range(im_bands): + outband = dataset.GetRasterBand(1 + i) + outband.WriteArray(im_data[i]) + outband.FlushCache() + # outRaster.GetRasterBand(i + 1).WriteArray(array[i]) + del dataset + + @staticmethod + def get_data(filename): + """ + :param filename: tif路径 + :return: 获取所有波段的数据 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + im_width = dataset.RasterXSize + im_height = dataset.RasterYSize + im_data = dataset.ReadAsArray(0, 0, im_width, im_height) + del dataset + return im_data + + @staticmethod + def get_scopes(ori_sim): + ori_sim_data = ImageHandler.get_data(ori_sim) + lon = ori_sim_data[0, :, :] + lat = ori_sim_data[1, :, :] + + min_lon = np.nanmin(lon) + max_lon = np.nanmax(lon) + min_lat = np.nanmin(lat) + max_lat = np.nanmax(lat) + + scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]] + return scopes + + @staticmethod + def get_dataset(filename): + """ + :param filename: tif路径 + :return: 获取所有波段的数据 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + return dataset + + @staticmethod + def get_all_band_array(filename): + """ + (大气延迟算法) + 将ERA-5影像所有波段存为一个数组, 波段数在第三维度 get_data()->(37,8,8) + :param filename: 影像路径 get_all_band_array ->(8,8,37) + :return: 影像数组 + """ + dataset = gdal.Open(filename) + x_size = dataset.RasterXSize + y_size = dataset.RasterYSize + nums = dataset.RasterCount + array = np.zeros((y_size, x_size, nums), dtype=float) + if nums == 1: + bands_0 = dataset.GetRasterBand(1) + array = bands_0.ReadAsArray(0, 0, x_size, y_size) + else: + for i in range(0, nums): + bands = dataset.GetRasterBand(i+1) + arr = bands.ReadAsArray(0, 0, x_size, y_size) + array[:, :, i] = arr + return array + + @staticmethod + def get_img_width(filename): + """ + :param filename: tif路径 + :return: 影像宽度 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + width = dataset.RasterXSize + + del dataset + return width + + @staticmethod + def get_img_height(filename): + """ + :param filename: tif路径 + :return: 影像高度 + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + height = dataset.RasterYSize + del dataset + return height + + @staticmethod + def read_img(filename): + """ + 影像读取 + :param filename: + :return: + """ + gdal.AllRegister() + img_dataset = gdal.Open(filename) # 打开文件 + + if img_dataset is None: + msg = 'Could not open ' + filename + print(msg) + return None, None, None + + im_proj = img_dataset.GetProjection() # 地图投影信息 + if im_proj is None: + return None, None, None + im_geotrans = img_dataset.GetGeoTransform() # 仿射矩阵 + + im_width = img_dataset.RasterXSize # 栅格矩阵的行数 + im_height = img_dataset.RasterYSize # 栅格矩阵的行数 + im_arr = img_dataset.ReadAsArray(0, 0, im_width, im_height) + del img_dataset + return im_proj, im_geotrans, im_arr + + def cal_img_scope(self, dataset): + """ + 计算影像的地理坐标范围 + 根据GDAL的六参数模型将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换) + :param dataset :GDAL地理数据 + :return: list[point_upleft, point_upright, point_downleft, point_downright] + """ + if dataset is None: + return None + + img_geotrans = dataset.GetGeoTransform() + if img_geotrans is None: + return None + + width = dataset.RasterXSize # 栅格矩阵的列数 + height = dataset.RasterYSize # 栅格矩阵的行数 + + point_upleft = self.trans_rowcol2geo(img_geotrans, 0, 0) + point_upright = self.trans_rowcol2geo(img_geotrans, width, 0) + point_downleft = self.trans_rowcol2geo(img_geotrans, 0, height) + point_downright = self.trans_rowcol2geo(img_geotrans, width, height) + + return [point_upleft, point_upright, point_downleft, point_downright] + + @staticmethod + def get_scope_ori_sim(filename): + """ + 计算影像的地理坐标范围 + 根据GDAL的六参数模型将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换) + :param dataset :GDAL地理数据 + :return: list[point_upleft, point_upright, point_downleft, point_downright] + """ + gdal.AllRegister() + dataset = gdal.Open(filename) + if dataset is None: + return None + + width = dataset.RasterXSize # 栅格矩阵的列数 + height = dataset.RasterYSize # 栅格矩阵的行数 + + band1 = dataset.GetRasterBand(1) + array1 = band1.ReadAsArray(0, 0, band1.XSize, band1.YSize) + + band2 = dataset.GetRasterBand(2) + array2 = band2.ReadAsArray(0, 0, band2.XSize, band2.YSize) + + if array1[0, 0] < array1[0, width-1]: + point_upleft = [array1[0, 0], array2[0, 0]] + point_upright = [array1[0, width-1], array2[0, width-1]] + else: + point_upright = [array1[0, 0], array2[0, 0]] + point_upleft = [array1[0, width-1], array2[0, width-1]] + + + if array1[height-1, 0] < array1[height-1, width-1]: + point_downleft = [array1[height - 1, 0], array2[height - 1, 0]] + point_downright = [array1[height - 1, width - 1], array2[height - 1, width - 1]] + else: + point_downright = [array1[height - 1, 0], array2[height - 1, 0]] + point_downleft = [array1[height - 1, width - 1], array2[height - 1, width - 1]] + + + if(array2[0, 0] < array2[height - 1, 0]): + #上下调换顺序 + tmp1 = point_upleft + point_upleft = point_downleft + point_downleft = tmp1 + + tmp2 = point_upright + point_upright = point_downright + point_downright = tmp2 + + return [point_upleft, point_upright, point_downleft, point_downright] + + + @staticmethod + def trans_rowcol2geo(img_geotrans,img_col, img_row): + """ + 据GDAL的六参数模型仿射矩阵将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换) + :param img_geotrans: 仿射矩阵 + :param img_col:图像纵坐标 + :param img_row:图像横坐标 + :return: [geo_x,geo_y] + """ + geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row + geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row + return [geo_x, geo_y] + + @staticmethod + def write_era_into_img(filename, im_proj, im_geotrans, im_data): + """ + 影像保存 + :param filename: + :param im_proj: + :param im_geotrans: + :param im_data: + :return: + """ + gdal_dtypes = { + 'int8': gdal.GDT_Byte, + 'unit16': gdal.GDT_UInt16, + 'int16': gdal.GDT_Int16, + 'unit32': gdal.GDT_UInt32, + 'int32': gdal.GDT_Int32, + 'float32': gdal.GDT_Float32, + 'float64': gdal.GDT_Float64, + } + if not gdal_dtypes.get(im_data.dtype.name, None) is None: + datatype = gdal_dtypes[im_data.dtype.name] + else: + datatype = gdal.GDT_Float32 + + # 判读数组维数 + if len(im_data.shape) == 3: + im_height, im_width, im_bands = im_data.shape # shape[0] 行数 + else: + im_bands, (im_height, im_width) = 1, im_data.shape + + # 创建文件 + if os.path.exists(os.path.split(filename)[0]) is False: + os.makedirs(os.path.split(filename)[0]) + + driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间 + dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + dataset.SetProjection(im_proj) # 写入投影 + + if im_bands == 1: + dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据 + else: + for i in range(im_bands): + dataset.GetRasterBand(i + 1).WriteArray(im_data[:, :, i]) + # dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) + del dataset + + # 写GeoTiff文件 + + @staticmethod + def lat_lon_to_pixel(raster_dataset_path, location): + """From zacharybears.com/using-python-to-translate-latlon-locations-to-pixels-on-a-geotiff/.""" + gdal.AllRegister() + raster_dataset = gdal.Open(raster_dataset_path) + if raster_dataset is None: + return None + ds = raster_dataset + gt = ds.GetGeoTransform() + srs = osr.SpatialReference() + srs.ImportFromWkt(ds.GetProjection()) + srs_lat_lon = srs.CloneGeogCS() + ct = osr.CoordinateTransformation(srs_lat_lon, srs) + new_location = [None, None] + # Change the point locations into the GeoTransform space + (new_location[1], new_location[0], holder) = ct.TransformPoint(location[1], location[0]) + # Translate the x and y coordinates into pixel values + Xp = new_location[0] + Yp = new_location[1] + dGeoTrans = gt + dTemp = dGeoTrans[1] * dGeoTrans[5] - dGeoTrans[2] * dGeoTrans[4] + Xpixel = (dGeoTrans[5] * (Xp - dGeoTrans[0]) - dGeoTrans[2] * (Yp - dGeoTrans[3])) / dTemp + Yline = (dGeoTrans[1] * (Yp - dGeoTrans[3]) - dGeoTrans[4] * (Xp - dGeoTrans[0])) / dTemp + del raster_dataset + return (Xpixel, Yline) + + @staticmethod + def write_img(filename, im_proj, im_geotrans, im_data, no_data='0'): + """ + 影像保存 + :param filename: 保存的路径 + :param im_proj: + :param im_geotrans: + :param im_data: + :param no_data: 把无效值设置为 nodata + :return: + """ + + gdal_dtypes = { + 'int8': gdal.GDT_Byte, + 'unit16': gdal.GDT_UInt16, + 'int16': gdal.GDT_Int16, + 'unit32': gdal.GDT_UInt32, + 'int32': gdal.GDT_Int32, + 'float32': gdal.GDT_Float32, + 'float64': gdal.GDT_Float64, + } + if not gdal_dtypes.get(im_data.dtype.name, None) is None: + datatype = gdal_dtypes[im_data.dtype.name] + else: + datatype = gdal.GDT_Float32 + flag = False + # 判读数组维数 + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + flag = True + else: + im_bands, (im_height, im_width) = 1, im_data.shape + + # 创建文件 + if os.path.exists(os.path.split(filename)[0]) is False: + os.makedirs(os.path.split(filename)[0]) + + driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间 + dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) + + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + + dataset.SetProjection(im_proj) # 写入投影 + + if im_bands == 1: + # outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据 + if flag: + outband = dataset.GetRasterBand(1) + outband.WriteArray(im_data[0]) + if no_data != 'null': + outband.SetNoDataValue(np.double(no_data)) + outband.FlushCache() + else: + outband = dataset.GetRasterBand(1) + outband.WriteArray(im_data) + if no_data != 'null': + outband.SetNoDataValue(np.double(no_data)) + outband.FlushCache() + else: + for i in range(im_bands): + outband = dataset.GetRasterBand(1 + i) + outband.WriteArray(im_data[i]) + if no_data != 'null': + outband.SetNoDataValue(np.double(no_data)) + outband.FlushCache() + # outRaster.GetRasterBand(i + 1).WriteArray(array[i]) + del dataset + + # 写GeoTiff文件 + + @staticmethod + def write_img_envi(filename, im_proj, im_geotrans, im_data, no_data='null'): + """ + 影像保存 + :param filename: 保存的路径 + :param im_proj: + :param im_geotrans: + :param im_data: + :param no_data: 把无效值设置为 nodata + :return: + """ + + gdal_dtypes = { + 'int8': gdal.GDT_Byte, + 'unit16': gdal.GDT_UInt16, + 'int16': gdal.GDT_Int16, + 'unit32': gdal.GDT_UInt32, + 'int32': gdal.GDT_Int32, + 'float32': gdal.GDT_Float32, + 'float64': gdal.GDT_Float64, + } + if not gdal_dtypes.get(im_data.dtype.name, None) is None: + datatype = gdal_dtypes[im_data.dtype.name] + else: + datatype = gdal.GDT_Float32 + + # 判读数组维数 + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + else: + im_bands, (im_height, im_width) = 1, im_data.shape + + # 创建文件 + if os.path.exists(os.path.split(filename)[0]) is False: + os.makedirs(os.path.split(filename)[0]) + + driver = gdal.GetDriverByName("ENVI") # 数据类型必须有,因为要计算需要多大内存空间 + dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) + + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + + dataset.SetProjection(im_proj) # 写入投影 + + if im_bands == 1: + # outRaster.GetRasterBand(1).WriteArray(array) # 写入数组数据 + outband = dataset.GetRasterBand(1) + outband.WriteArray(im_data) + if no_data != 'null': + outband.SetNoDataValue(no_data) + outband.FlushCache() + else: + for i in range(im_bands): + outband = dataset.GetRasterBand(1 + i) + outband.WriteArray(im_data[i]) + outband.FlushCache() + # outRaster.GetRasterBand(i + 1).WriteArray(array[i]) + del dataset + + @staticmethod + def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict): + """ + 图像中写入rpc信息 + """ + # 判断栅格数据的数据类型 + if 'int8' in im_data.dtype.name: + datatype = gdal.GDT_Byte + elif 'int16' in im_data.dtype.name: + datatype = gdal.GDT_Int16 + else: + datatype = gdal.GDT_Float32 + + # 判读数组维数 + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + else: + im_bands, (im_height, im_width) = 1, im_data.shape + + # 创建文件 + driver = gdal.GetDriverByName("GTiff") + dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) + + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + dataset.SetProjection(im_proj) # 写入投影 + + # 写入RPC参数 + for k in rpc_dict.keys(): + dataset.SetMetadataItem(k, rpc_dict[k], 'RPC') + + if im_bands == 1: + dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据 + else: + for i in range(im_bands): + dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) + + del dataset + + + def transtif2mask(self,out_tif_path, in_tif_path, threshold): + """ + :param out_tif_path:输出路径 + :param in_tif_path:输入的路径 + :param threshold:阈值 + """ + im_proj, im_geotrans, im_arr, im_scope = self.read_img(in_tif_path) + im_arr_mask = (im_arr < threshold).astype(int) + self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask) + + def write_quick_view(self, tif_path, color_img=False, quick_view_path=None): + """ + 生成快视图,默认快视图和影像同路径且同名 + :param tif_path:影像路径 + :param color_img:是否生成随机伪彩色图 + :param quick_view_path:快视图路径 + """ + if quick_view_path is None: + quick_view_path = os.path.splitext(tif_path)[0]+'.jpg' + + n = self.get_bands(tif_path) + if n == 1: # 单波段 + t_data = self.get_data(tif_path) + else: # 多波段,转为强度数据 + t_data = self.get_data(tif_path) + t_data = t_data.astype(float) + t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2) + t_data[np.isnan(t_data)] = 0 + t_data[np.where(t_data == -9999)] = 0 + t_r = self.get_img_height(tif_path) + t_c = self.get_img_width(tif_path) + if t_r > 10000 or t_c > 10000: + q_r = int(t_r / 10) + q_c = int(t_c / 10) + elif 1024 < t_r < 10000 or 1024 < t_c < 10000: + if t_r > t_c: + q_r = 1024 + q_c = int(t_c/t_r * 1024) + else: + q_c = 1024 + q_r = int(t_r/t_c * 1024) + else: + q_r = t_r + q_c = t_c + + if color_img is True: + # 生成伪彩色图 + img = np.zeros((t_r, t_c, 3), dtype=np.uint8) # (高,宽,维度) + u = np.unique(t_data) + for i in u: + if i != 0: + w = np.where(t_data == i) + img[w[0], w[1], 0] = np.random.randint(0, 255) # 随机生成一个0到255之间的整数 可以通过挑参数设定不同的颜色范围 + img[w[0], w[1], 1] = np.random.randint(0, 255) + img[w[0], w[1], 2] = np.random.randint(0, 255) + + img = cv2.resize(img, (q_c, q_r)) # (宽,高) + cv2.imwrite(quick_view_path, img) + # cv2.imshow("result4", img) + # cv2.waitKey(0) + else: + # 灰度图 + min = np.percentile(t_data, 2) # np.nanmin(t_data) + max = np.percentile(t_data, 98) # np.nanmax(t_data) + # if (max - min) < 256: + t_data = (t_data - min) / (max - min) * 255 + out_img = Image.fromarray(t_data) + out_img = out_img.resize((q_c, q_r)) # 重采样 + out_img = out_img.convert("L") # 转换成灰度图 + out_img.save(quick_view_path) + + def get_center_scopes(self, dataset): + if dataset is None: + return None + + img_geotrans = dataset.GetGeoTransform() + if img_geotrans is None: + return None + + width = dataset.RasterXSize # 栅格矩阵的列数 + height = dataset.RasterYSize # 栅格矩阵的行数 + + x_split = int(width/5) + y_split = int(height/5) + img_col_start = x_split * 1 + img_col_end = x_split * 3 + img_row_start = y_split * 1 + img_row_end = y_split *3 + cols = img_col_end - img_col_start + rows = img_row_end - img_row_start + if cols > 10000 or rows > 10000: + img_col_end = img_col_start + 10000 + img_row_end = img_row_start + 10000 + + point_upleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_start) + point_upright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_start) + point_downleft = self.trans_rowcol2geo(img_geotrans, img_col_start, img_row_end) + point_downright = self.trans_rowcol2geo(img_geotrans, img_col_end, img_row_end) + + return [point_upleft, point_upright, point_downleft, point_downright] + def write_view(self, tif_path, color_img=False, quick_view_path=None): + """ + 生成快视图,默认快视图和影像同路径且同名 + :param tif_path:影像路径 + :param color_img:是否生成随机伪彩色图 + :param quick_view_path:快视图路径 + """ + if quick_view_path is None: + quick_view_path = os.path.splitext(tif_path)[0]+'.jpg' + + n = self.get_bands(tif_path) + if n == 1: # 单波段 + t_data = self.get_data(tif_path) + else: # 多波段,转为强度数据 + t_data = self.get_data(tif_path) + t_data = t_data.astype(float) + t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2) + t_data[np.isnan(t_data)] = 0 + t_data[np.where(t_data == -9999)] = 0 + t_r = self.get_img_height(tif_path) + t_c = self.get_img_width(tif_path) + q_r = t_r + q_c = t_c + + if color_img is True: + # 生成伪彩色图 + img = np.zeros((t_r, t_c, 3), dtype=np.uint8) # (高,宽,维度) + u = np.unique(t_data) + for i in u: + if i != 0: + w = np.where(t_data == i) + img[w[0], w[1], 0] = np.random.randint(0, 255) # 随机生成一个0到255之间的整数 可以通过挑参数设定不同的颜色范围 + img[w[0], w[1], 1] = np.random.randint(0, 255) + img[w[0], w[1], 2] = np.random.randint(0, 255) + + img = cv2.resize(img, (q_c, q_r)) # (宽,高) + cv2.imwrite(quick_view_path, img) + # cv2.imshow("result4", img) + # cv2.waitKey(0) + else: + # 灰度图 + min = np.percentile(t_data, 2) # np.nanmin(t_data) + max = np.percentile(t_data, 98) # np.nanmax(t_data) + # if (max - min) < 256: + t_data = (t_data - min) / (max - min) * 255 + out_img = Image.fromarray(t_data) + out_img = out_img.resize((q_c, q_r)) # 重采样 + out_img = out_img.convert("L") # 转换成灰度图 + out_img.save(quick_view_path) + + return quick_view_path + + def limit_field(self, out_path, in_path, min_value, max_value): + """ + :param out_path:输出路径 + :param in_path:主mask路径,输出影像采用主mask的地理信息 + :param min_value + :param max_value + """ + proj = self.get_projection(in_path) + geotrans = self.get_geotransform(in_path) + array = self.get_band_array(in_path, 1) + array[array < min_value] = min_value + array[array > max_value] = max_value + self.write_img(out_path, proj, geotrans, array) + return True + + def band_merge(self, lon, lat, ori_sim): + lon_arr = self.get_data(lon) + lat_arr = self.get_data(lat) + temp = np.zeros((2, lon_arr.shape[0], lon_arr.shape[1]), dtype=float) + temp[0, :, :] = lon_arr[:, :] + temp[1, :, :] = lat_arr[:, :] + self.write_img(ori_sim, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], temp, '0') + + + def get_scopes(self, ori_sim): + ori_sim_data = self.get_data(ori_sim) + lon = ori_sim_data[0, :, :] + lat = ori_sim_data[1, :, :] + + min_lon = np.nanmin(np.where((lon != 0) & ~np.isnan(lon), lon, np.inf)) + max_lon = np.nanmax(np.where((lon != 0) & ~np.isnan(lon), lon, -np.inf)) + min_lat = np.nanmin(np.where((lat != 0) & ~np.isnan(lat), lat, np.inf)) + max_lat = np.nanmax(np.where((lat != 0) & ~np.isnan(lat), lat, -np.inf)) + + scopes = [[min_lon, max_lat], [max_lon, max_lat], [min_lon, min_lat], [max_lon, min_lat]] + return scopes + + @staticmethod + def dem_merged(in_dem_path, out_dem_path): + ''' + DEM重采样函数,默认坐标系为WGS84 + agrs: + in_dem_path: 输入的DEM文件夹路径 + meta_file_path: 输入的xml元文件路径 + out_dem_path: 输出的DEM文件夹路径 + ''' + # 读取文件夹中所有的DEM + dem_file_paths = [os.path.join(in_dem_path, dem_name) for dem_name in os.listdir(in_dem_path) if + dem_name.find(".tif") >= 0 and dem_name.find(".tif.") == -1] + spatialreference = osr.SpatialReference() + spatialreference.SetWellKnownGeogCS("WGS84") # 设置地理坐标,单位为度 degree # 设置投影坐标,单位为度 degree + spatialproj = spatialreference.ExportToWkt() # 导出投影结果 + # 将DEM拼接成一张大图 + mergeFile = gdal.BuildVRT(os.path.join(out_dem_path, "mergedDEM_VRT.tif"), dem_file_paths) + out_DEM = os.path.join(out_dem_path, "mergedDEM.tif") + gdal.Warp(out_DEM, + mergeFile, + format="GTiff", + dstSRS=spatialproj, + dstNodata=-9999, + outputType=gdal.GDT_Float32) + time.sleep(3) + # gdal.CloseDir(out_DEM) + return out_DEM + + @staticmethod + def landcover_merged(in_dem_path, out_dem_path): + ''' + DEM重采样函数,默认坐标系为WGS84 + agrs: + in_dem_path: 输入的DEM文件夹路径 + meta_file_path: 输入的xml元文件路径 + out_dem_path: 输出的DEM文件夹路径 + ''' + # 读取文件夹中所有的DEM + dem_file_paths = [os.path.join(in_dem_path, dem_name) for dem_name in os.listdir(in_dem_path) if + dem_name.find(".tif") >= 0 and dem_name.find(".tif.") == -1] + spatialreference = osr.SpatialReference() + spatialreference.SetWellKnownGeogCS("WGS84") # 设置地理坐标,单位为度 degree # 设置投影坐标,单位为度 degree + spatialproj = spatialreference.ExportToWkt() # 导出投影结果 + # 将DEM拼接成一张大图 + mergeFile = gdal.BuildVRT(os.path.join(out_dem_path, "mergedDEM_VRT.tif"), dem_file_paths) + out_DEM = os.path.join(out_dem_path, "mergedDEM.tif") + gdal.Warp(out_DEM, + mergeFile, + format="GTiff", + dstSRS=spatialproj, + dstNodata=-9999, + outputType=gdal.GDT_Byte) + time.sleep(3) + # gdal.CloseDir(out_DEM) + return out_DEM + + @staticmethod + def get_inc_angle(inc_xml, rows, cols, out_path): + tree = ElementTree() + tree.parse(inc_xml) # 影像头文件 + root = tree.getroot() + values = root.findall('incidenceValue') + angle_value = [value.text for value in values] + angle_value = np.array(angle_value) + inc_angle = np.tile(angle_value, (rows, 1)) + ImageHandler.write_img(out_path, '', [0.0, 1.0, 0.0, 0.0, 0.0, 1.0], inc_angle) + + pass + + +if __name__ == '__main__': + cols = 7086 + rows = 8064 + inc_xml = r"D:\micro\WorkSpace\Dem\Temporary\processing\product\GF3_SAY_FSI_001614_E113.2_N34.5_20161129_L1A_HHHV_L10002015686-DEM.tiff" + # ImageHandler().write_quick_view(inc_xml) + # fn = r'E:\202306hb\result\GF3B_SYC_QPSI_008316_E116.1_N43.3_20230622_L1A_AHV_L10000202892-cal-SMC.tif' + # out = r'E:\202306hb\result\soil.tif' + # # + # im_proj, im_geotrans, im_arr = ImageHandler.read_img(fn) + # im_arr[np.where(np.isnan(im_arr))] = 0 + # # h,w = im_arr.shape + # # arr = np.random.rand(h,w)*0.4 + # ImageHandler.write_img(out, im_proj, im_geotrans, im_arr, '-1') + +# path = r'D:\BaiduNetdiskDownload\GZ\lon.rdr' +# path2 = r'D:\BaiduNetdiskDownload\GZ\lat.rdr' +# path3 = r'D:\BaiduNetdiskDownload\GZ\lon_lat.tif' +# s = ImageHandler().band_merge(path, path2, path3) +# print(s) +# pass + fileDir = r'D:\micro\WorkSpace\ortho\Temporary\VH_db' + outDir = r'D:\micro\WorkSpace\ortho\Temporary\VH_db\test' + files = os.listdir(fileDir) + for file in files: + tifFile = os.path.join(fileDir, file) + outFile = os.path.join(outDir, file) + im_proj, im_geotrans, im_arr = ImageHandler.read_img(tifFile) + im_arr[np.isnan(im_arr)] = 0 + im_arr[np.isinf(im_arr)] = 0 + ImageHandler.write_img(outFile, im_proj, im_geotrans, im_arr, '0') diff --git a/tools/SpacetySliceDataTools/DataSampleSliceRaster.py b/tools/SpacetySliceDataTools/DataSampleSliceRaster.py new file mode 100644 index 0000000..4a781a1 --- /dev/null +++ b/tools/SpacetySliceDataTools/DataSampleSliceRaster.py @@ -0,0 +1,508 @@ +import os +import argparse +from osgeo import ogr,gdal +from matplotlib import pyplot as plt +from osgeo import gdal +import matplotlib +from matplotlib import pyplot as plt +import matplotlib.patches as patches +from osgeo import gdal +from PIL import Image +from scipy.spatial import cKDTree +import numpy as np +from DotaOperator import DotaObj,createDota,readDotaFile,writerDotaFile +import argparse +import math +from math import ceil, floor +########################################################################## +# 函数区 +########################################################################## +def read_tif(path): + dataset = gdal.Open(path) # 打开TIF文件 + if dataset is None: + print("无法打开文件") + return None, None, None + + cols = dataset.RasterXSize # 图像宽度 + rows = dataset.RasterYSize # 图像高度 + bands = dataset.RasterCount + im_proj = dataset.GetProjection() # 获取投影信息 + im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息 + im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组 + print("行数:", rows) + print("列数:", cols) + print("波段:", bands) + del dataset # 关闭数据集 + return im_proj, im_Geotrans, im_data + +def write_envi(im_data, im_geotrans, im_proj, output_path): + """ + 将数组数据写入ENVI格式文件 + :param im_data: 输入的numpy数组(2D或3D) + :param im_geotrans: 仿射变换参数(6元组) + :param im_proj: 投影信息(WKT字符串) + :param output_path: 输出文件路径(无需扩展名,会自动生成.dat和.hdr) + """ + im_bands = 1 + im_height, im_width = im_data.shape + # 创建ENVI格式驱动 + driver = gdal.GetDriverByName("ENVI") + dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte) + + if dataset is not None: + dataset.SetGeoTransform(im_geotrans) # 设置地理变换参数 + dataset.SetProjection(im_proj) # 设置投影 + + dataset.GetRasterBand(1).WriteArray(im_data) + + dataset.FlushCache() # 确保数据写入磁盘 + dataset = None # 关闭文件 + +def geoXY2pixelXY(geo_x, geo_y, inv_gt): + pixel_x = inv_gt[0] + geo_x * inv_gt[1] + geo_y * inv_gt[2] + pixel_y = inv_gt[3] + geo_x * inv_gt[4] + geo_y * inv_gt[5] + return pixel_x, pixel_y + +def label2pixelpoints(dotapath,tiff_inv_trans,methodstr): + dotalist = readDotaFile(dotapath) + if methodstr=="geolabel": + for i in range(len(dotalist)): + geo_x = dotalist[i].x1 # x1 + geo_y = dotalist[i].y1 + pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans) + dotalist[i].x1 = pixel_x + dotalist[i].y1 = pixel_y + + geo_x = dotalist[i].x2 # x2 + geo_y = dotalist[i].y2 + pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans) + dotalist[i].x2 = pixel_x + dotalist[i].y2 = pixel_y + + geo_x = dotalist[i].x3 # x3 + geo_y = dotalist[i].y3 + pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans) + dotalist[i].x3 = pixel_x + dotalist[i].y3 = pixel_y + + geo_x = dotalist[i].x4 # x4 + geo_y = dotalist[i].y4 + pixel_x, pixel_y = geoXY2pixelXY(geo_x, geo_y, tiff_inv_trans) + dotalist[i].x4 = pixel_x + dotalist[i].y4 = pixel_y + + print("点数:", len(dotalist)) + return dotalist + +def getMaxEdge(dotalist, ids): + cornpoint = np.zeros((len(ids) * 4, 2)) + for idx in range(len(ids)): + cornpoint[idx * 4 + 0, 0] = dotalist[ids[idx]].x1 + cornpoint[idx * 4 + 1, 0] = dotalist[ids[idx]].x2 + cornpoint[idx * 4 + 2, 0] = dotalist[ids[idx]].x3 + cornpoint[idx * 4 + 3, 0] = dotalist[ids[idx]].x4 + + cornpoint[idx * 4 + 0, 1] = dotalist[ids[idx]].y1 + cornpoint[idx * 4 + 1, 1] = dotalist[ids[idx]].y2 + cornpoint[idx * 4 + 2, 1] = dotalist[ids[idx]].y3 + cornpoint[idx * 4 + 3, 1] = dotalist[ids[idx]].y4 + + xedge = np.max(cornpoint[:, 0]) - np.min(cornpoint[:, 0]) + yedge = np.max(cornpoint[:, 1]) - np.min(cornpoint[:, 1]) + + edgelen = xedge if xedge > yedge else yedge + return edgelen + +def getExternCenter(dotalist, ids): + cornpoint = np.zeros((len(ids) * 4, 2)) + for idx in range(len(ids)): + cornpoint[idx * 4 + 0, 0] = dotalist[ids[idx]].x1 + cornpoint[idx * 4 + 1, 0] = dotalist[ids[idx]].x2 + cornpoint[idx * 4 + 2, 0] = dotalist[ids[idx]].x3 + cornpoint[idx * 4 + 3, 0] = dotalist[ids[idx]].x4 + + cornpoint[idx * 4 + 0, 1] = dotalist[ids[idx]].y1 + cornpoint[idx * 4 + 1, 1] = dotalist[ids[idx]].y2 + cornpoint[idx * 4 + 2, 1] = dotalist[ids[idx]].y3 + cornpoint[idx * 4 + 3, 1] = dotalist[ids[idx]].y4 + + minX = np.min(cornpoint[:, 0]) + minY = np.min(cornpoint[:, 1]) + maxX = np.max(cornpoint[:, 0]) + maxY = np.max(cornpoint[:, 1]) + centerX = (minX + maxX) / 2 + centerY = (minY + maxY) / 2 + return [centerX, centerY, minX, minY, maxX, maxY] + +def drawSliceRasterPrivew(tiff_data,dotalist,clusterDict): + # 绘制图形 + # 创建图形和坐标轴 + fig, ax = plt.subplots(figsize=(20, 16)) + ax.imshow(tiff_data, cmap='gray') + # 绘制每个目标的矩形框并标注坐标 + for i in range(len(dotalist)): + # 提取x和y坐标 + x_coords = [dotalist[i].x1, dotalist[i].x2, dotalist[i].x3, dotalist[i].x4] + y_coords = [dotalist[i].y1, dotalist[i].y2, dotalist[i].y3, dotalist[i].y4] + + # 计算最小外接矩形(AABB) + x_min, x_max = min(x_coords), max(x_coords) + y_min, y_max = min(y_coords), max(y_coords) + width = x_max - x_min + height = y_max - y_min + + # 绘制无填充矩形框(仅红色边框) + rect = patches.Rectangle( + (x_min, y_min), width, height, + linewidth=2, edgecolor='red', facecolor='none' # 关键:facecolor='none' + ) + ax.add_patch(rect) + + # ax.annotate(f'({x},{y})', xy=(x, y), xytext=(5, 5), + # textcoords='offset points', fontsize=10, + # bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.8)) + + # 在矩形中心标注目标编号 + center_x = sum(x_coords) / 4 + center_y = sum(y_coords) / 4 + ax.text(center_x, center_y, str(i), + ha='center', va='center', fontsize=6, color='red') + + # 以类别中心为中心绘制四边形 + for k in clusterDict: + # 绘制无填充矩形框(仅红色边框) + minX = clusterDict[k]["p"][0] + minY = clusterDict[k]["p"][1] + rect = patches.Rectangle( + (minX , minY), 1024, 1024, + linewidth=2, edgecolor='green', facecolor='none' # 关键:facecolor='none' + ) + ax.add_patch(rect) + ax.text(minX+512, minY+512, str(k), + ha='center', va='center', fontsize=6, color='green') + + plt.tight_layout() + plt.show() + + print("绘图结束") + return None + +def find_optimal_slices(H, W, boxes, patch_size=1024, max_overlap_rate=0.2): + """ + Compute optimal slice positions for the image to maximize the number of fully contained rectangular patches (boxes), + while ensuring the overlap rate between any two slices does not exceed the specified maximum. + + Parameters: + - H: Height of the image. + - W: Width of the image. + - boxes: List of tuples or lists, each containing (x1, y1, x2, y2) where (x1, y1) is the top-left and (x2, y2) is the bottom-right of a rectangular patch. + - patch_size: Size of each slice (square, e.g., 1024). + - max_overlap_rate: Maximum allowed overlap rate (e.g., 0.2). + + Returns: + - slices: List of (sx, sy) starting positions for the slices. + - covered_count: Number of patches that appear fully in at least one slice. + """ + overlap_max = patch_size * max_overlap_rate + stride = patch_size - floor(overlap_max) # Ensures overlap <= max_overlap_rate in linear dimensions + + N = len(boxes) + x_covered = [set() for _ in range(stride)] + y_covered = [set() for _ in range(stride)] + + for i in range(N): + x1, y1, x2, y2 = boxes[i] + b_w = x2 - x1 + b_h = y2 - y1 + + # For x-dimension + lx = max(0, x2 - patch_size) + rx = min(W - patch_size, x1) + if lx <= rx: + start_x = ceil(lx) + end_x = floor(rx) + l_x = end_x - start_x + 1 + if l_x > 0: + if l_x >= stride: + for ox in range(stride): + x_covered[ox].add(i) + else: + for sx in range(start_x, end_x + 1): + ox = sx % stride + x_covered[ox].add(i) + + # For y-dimension + ly = max(0, y2 - patch_size) + ry = min(H - patch_size, y1) + if ly <= ry: + start_y = ceil(ly) + end_y = floor(ry) + l_y = end_y - start_y + 1 + if l_y > 0: + if l_y >= stride: + for oy in range(stride): + y_covered[oy].add(i) + else: + for sy in range(start_y, end_y + 1): + oy = sy % stride + y_covered[oy].add(i) + + # Find the best offset pair (ox, oy) that maximizes covered patches + max_covered = 0 + best_ox = 0 + best_oy = 0 + for ox in range(stride): + for oy in range(stride): + current_covered = len(x_covered[ox] & y_covered[oy]) + if current_covered > max_covered: + max_covered = current_covered + best_ox = ox + best_oy = oy + + # Generate the slice positions using the best offsets and stride + slices = [] + sx = best_ox + while sx + patch_size <= W: + sy = best_oy + while sy + patch_size <= H: + slices.append((sx, sy)) + sy += stride + sx += stride + + return slices, max_covered + +def check_B_in_A(A,B): + """ + 判断A包含B + :param A: [x0,y0.w.h] + :param B: [x0,y0.w.h] + :return: + """ + # 解构矩形A和B的参数 + Ax0, Ay0, Aw, Ah = A + Bx0, By0, Bw, Bh = B + + # 计算矩形A和B的右边界和下边界 + Ax1 = Ax0 + Aw + Ay1 = Ay0 + Ah + Bx1 = Bx0 + Bw + By1 = By0 + Bh + + # 判断B是否完全在A内部 + return (Bx0 >= Ax0) and (Bx1 <= Ax1) and (By0 >= Ay0) and (By1 <= Ay1) + + +########################################################################## +# 切分算法流程图 +########################################################################## + +def getclusterDict(dotalist,imgheight,imgwidth,pitchSize=1024,max_overlap_rate=0.2): + """ + 生成切片数据 + :param dotalist: 样本集 + :param imgheight: 图像高度 + :param imgwidth: 图像宽度 + :return: 切片类型 + """ + boxs=[] + for i in range(len(dotalist)): + xs=np.array([dotalist[i].x1,dotalist[i].x2,dotalist[i].x3, dotalist[i].x4]) + ys=np.array([dotalist[i].y1,dotalist[i].y2,dotalist[i].y3, dotalist[i].y4]) + x1=np.min(xs) + x2=np.max(xs) + y1=np.min(ys) + y2=np.max(ys) + boxs.append([x1,y1,x2,y2]) # x1, y1, x2, y2 = boxes[i] + + slices, max_covered=find_optimal_slices(imgheight,imgwidth,boxs,pitchSize,max_overlap_rate) + + clusterDict={} + + waitContaindota=[] + hasContainIds=[] + for i in range(len(slices)): + sx,sy=slices[i] + clusterDict[i]={"p":[sx,sy],"id":[]} + slicesExten=[sx,sy,1024,1024] + for ids in range(len(dotalist)): + if ids in hasContainIds: + continue + else: + [centerX, centerY, minX, minY, maxX, maxY]=getExternCenter(dotalist, [ids]) + dotaExtend=[minX,minY,maxX-minX,maxY-minY] + if check_B_in_A(slicesExten,dotaExtend): + clusterDict[i]["id"].append(ids) + hasContainIds.append(ids) + + for ids in range(len(dotalist)): + if ids in hasContainIds: + continue + else: + waitContaindota.append(ids) + print("No in slice dota : ",str(dotalist[ids]) ) + + print("no process ids ",str(waitContaindota)) + return clusterDict + + +def drawSlictplot(clusterDict,dotalist,tiff_data,nrows=10,ncols=9): + """ + :param clusterDict: clusterDict[i]={"p":[sx,sy],"id":[]} + :param dotalist: (x1, y1, x2, y2, x3, y3, x4, y4 clsname diffcule) + :return: + """ + fig, axes = plt.subplots(nrows=nrows,ncols=ncols,figsize=(20, 16)) + plt.tight_layout(pad=3.0) + + # 9*10 + subid=0 + for cid in clusterDict: + sx,sy=clusterDict[cid]["p"] + colid=subid//nrows + rowid=subid%nrows + subid=subid+1 + ax = axes[rowid, colid] + ax.set_title(str(cid)) + sliceData=tiff_data[sy:(sy+1024),sx:(sx+1024)] + ax.imshow(sliceData, cmap='gray') + + for did in clusterDict[cid]["id"] : + # 提取x和y坐标 + x_coords = [dotalist[did].x1-sx, dotalist[did].x2-sx, dotalist[did].x3-sx, dotalist[did].x4-sx] + y_coords = [dotalist[did].y1-sy, dotalist[did].y2-sy, dotalist[did].y3-sy, dotalist[did].y4-sy] + + # 计算最小外接矩形(AABB) + x_min, x_max = min(x_coords), max(x_coords) + y_min, y_max = min(y_coords), max(y_coords) + width = x_max - x_min + height = y_max - y_min + + + # 绘制无填充矩形框(仅红色边框) + rect = patches.Rectangle( + (x_min, y_min), width, height, + linewidth=2, edgecolor='red', facecolor='none' # 关键:facecolor='none' + ) + ax.add_patch(rect) + + # 在矩形中心标注目标编号 + center_x = x_min+width/2 + center_y = y_min+height/2 + ax.text(center_x, center_y, str(did), + ha='center', va='center', fontsize=6, color='red') + + + plt.tight_layout() + plt.show() + + print("绘图结束") + return None + + +def slictDataAndOutlabel(clusterDict,dotalist,tiff_data,tiff_basename,outfolderpath,im_geotrans, im_proj): + """ + 切分标签,输出结果与文件 + :param clusterDict: + :param dotalist: + :param tiff_data: + :param tiff_name: + :param outfolderpath: + :return: + """ + for cid in clusterDict: + sx, sy = clusterDict[cid]["p"] + sliceData = tiff_data[sy:(sy + 1024), sx:(sx + 1024)] + outbinname="{}_{}.bin".format(tiff_basename,cid) + outlabelname="{}_{}.txt".format(tiff_basename,cid) + # 获取样本列表 + outdotalist=[] + for did in clusterDict[cid]["id"] : + tempdota=dotalist[did] + tempdota.x1=tempdota.x1-sx + tempdota.x2=tempdota.x2-sx + tempdota.x3=tempdota.x3-sx + tempdota.x4=tempdota.x4-sx + tempdota.y1=tempdota.y1-sy + tempdota.y2=tempdota.y2-sy + tempdota.y3=tempdota.y3-sy + tempdota.y4=tempdota.y4-sy + outdotalist.append(tempdota) + + outlabelpath=os.path.join(outfolderpath,outlabelname) + outbinpath=os.path.join(outfolderpath,outbinname) + + temp_im_geotrans=[tempi for tempi in im_geotrans] + # 处理 0,3 + temp_im_geotrans[0]=im_geotrans[0]+sx*im_geotrans[1]+im_geotrans[2]*sy # x + temp_im_geotrans[3]=im_geotrans[3]+sx*im_geotrans[4]+im_geotrans[5]*sy # y + write_envi(sliceData,temp_im_geotrans,im_proj,outbinpath) + writerDotaFile(outdotalist,outlabelpath) + +########################################################################## +# 处理流程图 +########################################################################## +def DataSampleSliceRasterProcess(inbinfile,labelfilepath,outfolderpath,methodstr): + tiff_proj, tiff_trans, tiff_data = read_tif(inbinfile) + tiff_inv_trans = gdal.InvGeoTransform(tiff_trans) + dotalist=label2pixelpoints(labelfilepath,tiff_inv_trans,methodstr) + imgheight, imgwidth=tiff_data.shape + clusterDict=getclusterDict(dotalist,imgheight,imgwidth) + drawSliceRasterPrivew(tiff_data, dotalist, clusterDict) + ncols=int(len(clusterDict)/9+1) + drawSlictplot(clusterDict, dotalist, tiff_data, 9, ncols) + tiff_name=os.path.basename(inbinfile) + tiff_basename=os.path.splitext(tiff_name)[0] + slictDataAndOutlabel(clusterDict, dotalist, tiff_data, tiff_basename, outfolderpath, tiff_trans, tiff_proj) + + pass + + + +def getParams(): + parser = argparse.ArgumentParser() + parser.add_argument('-i','--inbinfile',type=str,default=r'F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.bin', help='输入tiff的bin文件') + parser.add_argument('-l', '--labelfilepath',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\标注\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01_LC.txt", help='输入标注') + parser.add_argument('-o', '--outfolderpath',type=str,default=r'F:\天仪SAR卫星数据集\舰船数据\切分3', help='切片文件夹地址') + group = parser.add_mutually_exclusive_group() + group.add_argument( + '--geolabel', + action='store_const', + const='geolabel', + dest='method', + help='标注坐标点为地理坐标' + ) + group.add_argument( + '--pixellabel', + action='store_const', + const='pixellabel', + dest='method', + help='标注坐标系统为输入影像的像空间坐标' + ) + parser.set_defaults(method='geolabel') + args = parser.parse_args() + return args + +if __name__ == '__main__': + parser = getParams() + inbinfile=parser.inbinfile + labelfilepath=parser.labelfilepath + outfolderpath=parser.outfolderpath + methodstr= parser.method + print('inbinfile=',inbinfile) + print('labelfilepath=',labelfilepath) + print('outfolderpath=',outfolderpath) + print('methodstr=',methodstr) + DataSampleSliceRasterProcess(inbinfile, labelfilepath, outfolderpath,methodstr) + print("样本切分完成") + + + + + + + + + + + + diff --git a/tools/SpacetySliceDataTools/DotaOperator.py b/tools/SpacetySliceDataTools/DotaOperator.py new file mode 100644 index 0000000..bd8fd6b --- /dev/null +++ b/tools/SpacetySliceDataTools/DotaOperator.py @@ -0,0 +1,80 @@ + +import os +import argparse + +""" +Dota数据集标注 +每条标注数据应包含: ++前 8 个数值,表示目标的四个角点的坐标(x1, y1, x2, y2, x3, y3, x4, y4),按顺时针或逆时针顺序排列,无需进行归一化处理; ++倒数第二列为标注的目标类别名称,如:jun_ship等; ++最后一列为识别目标的难易程度(difficulty) + +提供数据集标注 + +""" +class DotaObj(object): + def __init__(self, x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty): + self.x1=x1 + self.y1=y1 + self.x2=x2 + self.y2=y2 + self.x3=x3 + self.y3=y3 + self.x4=x4 + self.y4=y4 + self.clsname=clsname + self.difficulty=difficulty + def __str__(self): + return "{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}".format( + self.x1,self.y1, + self.x2,self.y2, + self.x3,self.y3, + self.x4,self.y4, + self.clsname,self.difficulty + ) + + +def createDota(x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty): + return DotaObj(x1, y1, x2, y2, x3, y3, x4, y4,clsname,difficulty) # 8+2 + + +def readDotaFile(dotafilepath): + content=None + with open(dotafilepath,'r',encoding="utf-8") as fp: + content=fp.read() + contentlines=content.split("\n") + # 逐行分解 + result=[] + for linestr in contentlines: + linestr=linestr.replace("\t"," ") + linestr=linestr.replace(" "," ") + linemetas=linestr.split(" ") + if(len(linemetas)>=10): + x1=float(linemetas[0]) + y1=float(linemetas[1]) + x2=float(linemetas[2]) + y2=float(linemetas[3]) + x3=float(linemetas[4]) + y3=float(linemetas[5]) + x4=float(linemetas[6]) + y4=float(linemetas[7]) + clsname=linemetas[8] + difficulty=linemetas[9] + result.append(createDota(x1, y1, x2, y2, x3, y3, x4, y4, clsname,difficulty)) + else: + print("parse result: ", linestr) + return result + + +def writerDotaFile(dotalist,filepath): + with open(filepath,'a',encoding="utf-8") as fp: + for dota in dotalist: + if isinstance(dota,DotaObj): + fp.write("{}\n".format(str(dota))) + else: + fp.write("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}\n".format( + dota[0],dota[1],dota[2],dota[3],dota[4],dota[5],dota[6],dota[7],dota[8],dota[9] + )) + + + diff --git a/tools/SpacetySliceDataTools/LabelShapefile2Dota.py b/tools/SpacetySliceDataTools/LabelShapefile2Dota.py new file mode 100644 index 0000000..677685c --- /dev/null +++ b/tools/SpacetySliceDataTools/LabelShapefile2Dota.py @@ -0,0 +1,39 @@ +from osgeo import ogr, gdal +import os +import argparse + +def shp_to_geojson(shp_path, geojson_path): + # 设置编码选项 + gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES") + gdal.SetConfigOption("SHAPE_ENCODING", "GBK") + + # 打开Shapefile + src_ds = ogr.Open(shp_path) + src_layer = src_ds.GetLayer(0) + + # 创建输出GeoJSON + driver = ogr.GetDriverByName('GeoJSON') + if os.path.exists(geojson_path): + driver.DeleteDataSource(geojson_path) + dst_ds = driver.CreateDataSource(geojson_path) + dst_layer = dst_ds.CreateLayer('output', src_layer.GetSpatialRef()) + + # 复制字段定义 + dst_layer.CreateFields(src_layer.schema) + + # 复制要素 + for feature in src_layer: + dst_feature = ogr.Feature(dst_layer.GetLayerDefn()) + dst_feature.SetGeometry(feature.GetGeometryRef()) + for j in range(feature.GetFieldCount()): + dst_feature.SetField(j, feature.GetField(j)) + dst_layer.CreateFeature(dst_feature) + + # 清理 + dst_ds = None + src_ds = None + print("转换完成") + + +# 使用示例 +shp_to_geojson('input.shp', 'output.geojson') \ No newline at end of file diff --git a/tools/SpacetySliceDataTools/SpacetyTIFFDataStretch2PNG.py b/tools/SpacetySliceDataTools/SpacetyTIFFDataStretch2PNG.py new file mode 100644 index 0000000..cfdaf31 --- /dev/null +++ b/tools/SpacetySliceDataTools/SpacetyTIFFDataStretch2PNG.py @@ -0,0 +1,220 @@ +from osgeo import ogr, gdal +import os +import argparse +import numpy as np +from PIL import Image + +def get_filename_without_ext(path): + base_name = os.path.basename(path) + if '.' not in base_name or base_name.startswith('.'): + return base_name + return base_name.rsplit('.', 1)[0] + +def read_tif(path): + dataset = gdal.Open(path) # 打开TIF文件 + if dataset is None: + print("无法打开文件") + return None, None, None + + cols = dataset.RasterXSize # 图像宽度 + rows = dataset.RasterYSize # 图像高度 + bands = dataset.RasterCount + im_proj = dataset.GetProjection() # 获取投影信息 + im_Geotrans = dataset.GetGeoTransform() # 获取仿射变换信息 + im_data = dataset.ReadAsArray(0, 0, cols, rows) # 读取栅格数据为NumPy数组 + print("行数:", rows) + print("列数:", cols) + print("波段:", bands) + del dataset # 关闭数据集 + return im_proj, im_Geotrans, im_data + +def write_envi(im_data, im_geotrans, im_proj, output_path): + """ + 将数组数据写入ENVI格式文件 + :param im_data: 输入的numpy数组(2D或3D) + :param im_geotrans: 仿射变换参数(6元组) + :param im_proj: 投影信息(WKT字符串) + :param output_path: 输出文件路径(无需扩展名,会自动生成.dat和.hdr) + """ + im_bands = 1 + im_height, im_width = im_data.shape + # 创建ENVI格式驱动 + driver = gdal.GetDriverByName("ENVI") + dataset = driver.Create(output_path, im_width, im_height, 1, gdal.GDT_Byte) + + if dataset is not None: + dataset.SetGeoTransform(im_geotrans) # 设置地理变换参数 + dataset.SetProjection(im_proj) # 设置投影 + + dataset.GetRasterBand(1).WriteArray(im_data) + + dataset.FlushCache() # 确保数据写入磁盘 + dataset = None # 关闭文件 + + +def Strech_linear(im_data): + im_data_dB=10*np.log10(im_data) + immask=np.isfinite(im_data_dB) + infmask = np.isinf(im_data_dB) + imvail_data=im_data[immask] + im_data_dB=0 + + minvalue=np.nanmin(imvail_data) + maxvalue=np.nanmax(imvail_data) + + infmask = np.isinf(im_data_dB) + im_data[infmask] = minvalue-100 + im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254+1 + im_data=np.clip(im_data,0,255) + return im_data.astype(np.uint8) + +def Strech_linear1(im_data): + im_data_dB = 10 * np.log10(im_data) + immask = np.isfinite(im_data_dB) + infmask = np.isinf(im_data_dB) + imvail_data = im_data[immask] + im_data_dB=0 + + minvalue=np.percentile(imvail_data,1) + maxvalue = np.percentile(imvail_data, 99) + + + im_data[infmask] = minvalue - 100 + im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1 + im_data = np.clip(im_data, 0, 255) + + return im_data.astype(np.uint8) + + +def Strech_linear2(im_data): + im_data_dB = 10 * np.log10(im_data) + immask = np.isfinite(im_data_dB) + infmask = np.isinf(im_data_dB) + imvail_data = im_data[immask] + im_data_dB = 0 + + minvalue = np.percentile(imvail_data, 2) + maxvalue = np.percentile(imvail_data, 98) + + im_data[infmask] = minvalue - 100 + im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1 + im_data = np.clip(im_data, 0, 255) + + return im_data.astype(np.uint8) + +def Strech_linear5(im_data): + im_data_dB = 10 * np.log10(im_data) + immask = np.isfinite(im_data_dB) + infmask = np.isinf(im_data_dB) + imvail_data = im_data[immask] + im_data_dB = 0 + + minvalue = np.percentile(imvail_data, 5) + maxvalue = np.percentile(imvail_data, 95) + + im_data[infmask] = minvalue - 100 + im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1 + im_data = np.clip(im_data, 0, 255) + + return im_data.astype(np.uint8) + +def Strech_SquareRoot(im_data): + im_data=np.sqrt(im_data) + immask = np.isfinite(im_data) + imvail_data = im_data[immask] + + minvalue=np.nanmin(imvail_data) + maxvalue=np.nanmax(imvail_data) + im_data = (im_data - minvalue) / (maxvalue - minvalue) * 254 + 1 + im_data = np.clip(im_data, 0, 255) + return im_data.astype(np.uint8) + +def DataStrech(im_data,strechmethod): + # [,"Linear1","Linear2","Linear5","SquareRoot"] + if strechmethod == "Linear" : + return Strech_linear(im_data) + elif strechmethod == "Linear1": + return Strech_linear1(im_data) + elif strechmethod == "Linear2": + return Strech_linear2(im_data) + elif strechmethod == "Linear5": + return Strech_linear5(im_data) + elif strechmethod == "SquareRoot": + return Strech_SquareRoot(im_data) + else: + return im_data.astype(np.uint8) + + + + + +def stretchProcess(infilepath,outfilepath,strechmethod): + im_proj, im_Geotrans, im_data=read_tif(infilepath) + envifilepath=get_filename_without_ext(outfilepath)+".bin" + envifilepath=os.path.join(os.path.dirname(outfilepath),envifilepath) + im_data = DataStrech(im_data,strechmethod) + im_data = im_data.astype(np.uint8) + write_envi(im_data,im_Geotrans,im_proj,envifilepath) + Image.fromarray(im_data).save(outfilepath,compress_level=0) + print("图像拉伸处理结束") + +def getParams(): + parser = argparse.ArgumentParser() + parser.add_argument('-i','--infile',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.tiff", help='输入shapefile文件') + parser.add_argument('-o', '--outfile',type=str,default=r"F:\天仪SAR卫星数据集\舰船数据\bc2-sp-org-vv-20250205t032055-021998-000036-0055ee-01.png", help='输出geojson文件') + group = parser.add_mutually_exclusive_group() + group.add_argument( + '--Linear', + action='store_const', + const='Linear', + dest='method', + help='线性拉伸' + ) + group.add_argument( + '--Linear1prec', + action='store_const', + const='Linear1', + dest='method', + help='1%线性拉伸' + ) + group.add_argument( + '--Linear2prec', + action='store_const', + const='Linear2', + dest='method', + help='2%线性拉伸' + ) + group.add_argument( + '--Linear5prec', + action='store_const', + const='Linear5', + dest='method', + help='5%线性拉伸' + ) + group.add_argument( + '--SquareRoot', + action='store_const', + const='SquareRoot', + dest='method', + help='平方根拉伸' + ) + parser.set_defaults(method='SquareRoot') + + args = parser.parse_args() + return args + +if __name__ == '__main__': + parser = getParams() + intiffPath=parser.infile + outbinPath=parser.outfile + methodstr=parser.method + print('infile=',intiffPath) + print('outfile=',outbinPath) + print('method=',methodstr) + stretchProcess(intiffPath, outbinPath, methodstr) + + + + + + diff --git a/tools/SpacetySliceDataTools/geojson2dota.py b/tools/SpacetySliceDataTools/geojson2dota.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/SpacetySliceDataTools/shapefile2dota.py b/tools/SpacetySliceDataTools/shapefile2dota.py new file mode 100644 index 0000000..afec6f1 --- /dev/null +++ b/tools/SpacetySliceDataTools/shapefile2dota.py @@ -0,0 +1,89 @@ +from osgeo import ogr +import os +import argparse + +def shapefile_to_dota(shp_path, output_path, class_field='class', difficulty_value=1): + """ + 将Shapefile转换为DOTA格式 + :param shp_path: Shapefile文件路径 + :param output_path: 输出目录 + :param class_field: 类别字段名 + :param difficulty_value: 难度默认字段 + """ + + # 注册所有驱动 + ogr.RegisterAll() + + # 打开Shapefile文件 + driver = ogr.GetDriverByName('ESRI Shapefile') + datasource = driver.Open(shp_path, 0) + if datasource is None: + print("无法打开Shapefile文件") + return + + # 获取图层 + layer = datasource.GetLayer() + + output_file = output_path + + with open(output_file, 'w',encoding="utf-8") as f: + # 写入DOTA格式头信息(可选) + # f.write('imagesource:unknown\n') + # f.write('gsd:1.0\n') + + # 遍历所有要素 + for feature in layer: + # 获取几何对象 + geom = feature.GetGeometryRef() + + # 获取类别和难度 + class_name = feature.GetField(class_field) if feature.GetField(class_field) else 'unknown' + difficulty = difficulty_value + + # 处理不同类型的几何图形 + if geom.GetGeometryName() == 'POLYGON': + # 获取多边形外环 + ring = geom.GetGeometryRef(0) + # 获取所有点 + points = [] + for i in range(ring.GetPointCount()): + points.append(ring.GetPoint(i)) + + # 确保有足够的点(至少4个) + if len(points) >= 4: + # 取前4个点作为DOTA格式的四个角点 + # 注意: DOTA要求按顺序排列(顺时针或逆时针) + x1, y1 = points[0][0], points[0][1] + x2, y2 = points[1][0], points[1][1] + x3, y3 = points[2][0], points[2][1] + x4, y4 = points[3][0], points[3][1] + + # 写入DOTA格式行 + line = f"{x1} {y1} {x2} {y2} {x3} {y3} {x4} {y4} {class_name} {difficulty}\n" + f.write(line) + + # 释放资源 + datasource.Destroy() + print("转换完毕") + +def getParams(): + parser = argparse.ArgumentParser() + parser.add_argument('-i','--infile',type=str,default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.shp', help='输入shapefile文件') + parser.add_argument('-o', '--outfile',type=str,default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.txt', help='输出geojson文件') + parser.add_argument('-c', '--clsname',type=str,default=r'Label', help='输出geojson文件') + parser.add_argument('-d', '--difficulty',type=int,default=1, help='输出geojson文件') + args = parser.parse_args() + return args + +if __name__ == '__main__': + parser = getParams() + inFilePath=parser.infile + outpath=parser.outfile + clsname=parser.clsname + difficulty=parser.difficulty + print('infile=',inFilePath) + print('outfile=',outpath) + print('clsname=',clsname) + print('difficulty=',difficulty) + shapefile_to_dota(inFilePath, outpath, clsname, difficulty) +# 使用示例 diff --git a/tools/SpacetySliceDataTools/shapefile2geojson.py b/tools/SpacetySliceDataTools/shapefile2geojson.py new file mode 100644 index 0000000..7083480 --- /dev/null +++ b/tools/SpacetySliceDataTools/shapefile2geojson.py @@ -0,0 +1,51 @@ +from osgeo import ogr, gdal +import os +import argparse + +def shp_to_geojson(shp_path, geojson_path): + # 设置编码选项 + gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES") + gdal.SetConfigOption("SHAPE_ENCODING", "GBK") + + # 打开Shapefile + src_ds = ogr.Open(shp_path) + src_layer = src_ds.GetLayer(0) + + # 创建输出GeoJSON + driver = ogr.GetDriverByName('GeoJSON') + if os.path.exists(geojson_path): + driver.DeleteDataSource(geojson_path) + dst_ds = driver.CreateDataSource(geojson_path) + dst_layer = dst_ds.CreateLayer('output', src_layer.GetSpatialRef()) + + # 复制字段定义 + dst_layer.CreateFields(src_layer.schema) + + # 复制要素 + for feature in src_layer: + dst_feature = ogr.Feature(dst_layer.GetLayerDefn()) + dst_feature.SetGeometry(feature.GetGeometryRef()) + for j in range(feature.GetFieldCount()): + dst_feature.SetField(j, feature.GetField(j)) + dst_layer.CreateFeature(dst_feature) + + # 清理 + dst_ds = None + src_ds = None + print("转换完成") + +def getParams(): + parser = argparse.ArgumentParser() + parser.add_argument('-i','--infile',default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.shp', help='输入shapefile文件') + parser.add_argument('-o', '--outfile',default=r'F:\天仪SAR卫星数据集\德清院-测试-天仪提供数据\bc2-sp-org-vv-20250210t160723-022200-000120-0056b8-01_石佳宁.json', help='输出geojson文件') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + parser = getParams() + inFilePath=parser.infile + outpath=parser.outfile + print('infile=',inFilePath) + print('outfile=',outpath) + shp_to_geojson(inFilePath, outpath) \ No newline at end of file