microproduct/vegetationHeight-L-SAR/VegetationHeightPrePro.py

320 lines
12 KiB
Python
Raw Normal View History

2023-08-28 10:17:29 +00:00
# -*- coding: UTF-8 -*-
"""
@Project microproduct
@File VegetationHeightPrePro.py
@Function @Function: 坐标转换,坐标系转换图像裁剪重投影重采样
@Author LMM
@Date 2021/8/25 14:17
@Version 1.0.0
"""
from shapely.geometry import Polygon # 导入 gdal库要放到这一句的后面不然会引起错误
import logging
# import shapely
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
import os
import shapefile
from shapely.errors import TopologicalError
import numpy as np
from PIL import Image
logger = logging.getLogger("mylog")
env_str = os.getcwd()
os.environ['PROJ_LIB'] = env_str
class PreProcess:
"""
预处理所有的影像配准
"""
def __init__(self):
pass
@staticmethod
def lonlat2geo(lat, lon):
"""
WGS84转平面坐标
Param: lat 为WGS_1984的纬度
Param: lon 为WGS_1984的经度
输出转换后的坐标x,y
"""
dstsrs1 = osr.SpatialReference()
dstsrs1.ImportFromEPSG(32649)
# print("输出投影pro_经纬度",dstSRS1)
dstsrs2 = osr.SpatialReference()
dstsrs2.ImportFromEPSG(4326)
# print("输出投影:",dstSRS2)
ct = osr.CoordinateTransformation(dstsrs2, dstsrs1)
coords = ct.TransformPoint(lat, lon)
# print("输出转换后的坐标x,y:",coords[:2])
return coords[:2]
#
# def transtif2mask(self, out_tif_path, in_tif_path, threshold):
# """
# :param out_tif_path:输出路径
# :param in_tif_path:输入的路径
# :param threshold:阈值
# """
# im_proj, im_geotrans, im_arr, im_scope = self.read_img(in_tif_path)
# im_arr_mask = (im_arr < threshold).astype(int)
# self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)
@staticmethod
def trans_geogcs2projcs(out_path, in_path):
"""
:param out_path:wgs84投影坐标影像保存路径
:param in_path:地理坐标影像输入路径
"""
# 创建文件
if os.path.exists(os.path.split(out_path)[0]) is False:
os.makedirs(os.path.split(out_path)[0])
options = gdal.WarpOptions(format='GTiff', srcSRS='EPSG:4326', dstSRS='EPSG:32649')
gdal.Warp(out_path, in_path, options=options)
@staticmethod
def trans_projcs2geogcs(out_path, in_path):
"""
:param out_path:wgs84投影坐标影像保存路径
:param in_path:地理坐标影像输入路径
"""
# 创建文件
if os.path.exists(os.path.split(out_path)[0]) is False:
os.makedirs(os.path.split(out_path)[0])
options = gdal.WarpOptions(format='GTiff', srcSRS='EPSG:32649', dstSRS='EPSG:4326')
gdal.Warp(out_path, in_path, options=options)
@staticmethod
def intersect_polygon(scopes_tuple):
"""
功能说明计算多边形相交的区域坐标;注意多边形区域会转变成凸区域再求交
:param scopes_tuple: 输入多个区域坐标的tuple
:return: 多边形相交的区域坐标((x0,y0),(x1,y1),..., (xn,yn))
"""
if len(scopes_tuple) < 2:
logger.error('len(scopes_tuple) < 2')
return
try:
# python四边形对象会自动计算四个点最后四个点顺序为左上 左下 右下 右上 左上
tmp = tuple(scopes_tuple[0])
poly_intersect = Polygon(tmp).convex_hull # 计算四点顺序
for i in range(len(scopes_tuple)-1):
polygon_next = Polygon(tuple(scopes_tuple[i+1])).convex_hull
if poly_intersect.intersects(polygon_next): # 相不相交
poly_intersect = poly_intersect.intersection(polygon_next) # 相交面积
else:
msg = 'Image:' + str(i) + 'range does not overlap!'
logger.error(msg)
return
return list(poly_intersect.boundary.coords)[:-1]
except shapely.geos.TopologicalError:
logger.error('shapely.geos.TopologicalError occurred!')
return
@staticmethod
def write_polygon_shp(out_shp_path, point_list):
"""
功能说明创建闭环的矢量文件
:param out_shp_path :矢量文件保存路径
:param point_list :装有闭环点的列表[[x0,y0],[x1,y1]...[xn,yn]]
:return: True or False
"""
# 为了支持中文路径,请添加下面这句代码
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
# 为了使属性表字段支持中文,请添加下面这句
gdal.SetConfigOption("SHAPE_ENCODING", "")
# 注册所有的驱动
ogr.RegisterAll()
# 创建数据这里以创建ESRI的shp文件为例
str_driver_name = "ESRI Shapefile"
o_driver = ogr.GetDriverByName(str_driver_name)
if o_driver is None:
msg = 'driver('+str_driver_name+')is invalid value'
logger.error(msg)
return False
# 创建数据源
if os.path.exists(out_shp_path) and os.path.isfile(out_shp_path): # 如果已存在同名文件
os.remove(out_shp_path) # 则删除之
o_ds = o_driver.CreateDataSource(out_shp_path)
if o_ds is None:
msg = 'create file failed!' + out_shp_path
logger.error(msg)
return False
# 创建图层,创建一个多边形图层
srs = osr.SpatialReference()
srs.ImportFromEPSG(4326) # 计算ztd需要用到经纬度所以此处采用经纬度坐标系
# srs.ImportFromEPSG(32649) # 投影坐标系空间参考WGS84
o_layer = o_ds.CreateLayer("TestPolygon", srs, ogr.wkbPolygon)
if o_layer is None:
msg = 'create coverage failed!'
logger.error(msg)
return False
# 下面创建属性表
# 先创建一个叫FieldID的整型属性
o_field_id = ogr.FieldDefn("FieldID", ogr.OFTInteger)
o_layer.CreateField(o_field_id, 1)
# 再创建一个叫FeatureName的字符型属性字符长度为50
o_field_name = ogr.FieldDefn("FieldName", ogr.OFTString)
o_field_name.SetWidth(100)
o_layer.CreateField(o_field_name, 1)
o_defn = o_layer.GetLayerDefn()
# 创建矩形要素
o_feature_rectangle = ogr.Feature(o_defn)
o_feature_rectangle.SetField(0, 1)
o_feature_rectangle.SetField(1, "IntersectRegion")
# 创建环对象ring
ring = ogr.Geometry(ogr.wkbLinearRing)
for i in range(len(point_list)):
ring.AddPoint(point_list[i][0], point_list[i][1])
ring.CloseRings()
# 创建环对象polygon
geom_rect_polygon = ogr.Geometry(ogr.wkbPolygon)
geom_rect_polygon.AddGeometry(ring)
o_feature_rectangle.SetGeometry(geom_rect_polygon)
o_layer.CreateFeature(o_feature_rectangle)
o_ds.Destroy()
return True
@staticmethod
def cut_img(output_path, input_path, shp_path):
"""
:param output_path:剪切后的影像
:param input_path:待剪切的影像
:param shp_path:矢量数据
:return: True or False
"""
r = shapefile.Reader(shp_path)
box = r.bbox
input_dataset = gdal.Open(input_path)
gdal.Warp(output_path, input_dataset, format='GTiff', outputBounds=box, cutlineDSName=shp_path, dstNodata=-9999)
# cutlineWhere="FIELD = whatever",
# optionally you can filter your cutline (shapefile) based on attribute values
# select the no data value you like
# ds = None
# do other stuff with ds object, it is your cropped dataset. in this case we only close the dataset.
del input_dataset
return True
@staticmethod
def resampling_by_scale(input_path, target_file, refer_img_path):
"""
按照缩放比例对影像重采样
:param input_path: GDAL地理数据路径
:param target_file: 输出影像
:param refer_img_path:参考影像
:return: True or False
"""
ref_dataset = gdal.Open(refer_img_path)
ref_cols = ref_dataset.RasterXSize # 列数
ref_rows = ref_dataset.RasterYSize # 行数
dataset = gdal.Open(input_path) # 判断数据是否存在
if dataset is None:
logger.error('resampling_by_scale:dataset is None!')
return False
band_count = dataset.RasterCount # 波段数
if (band_count == 0) or (target_file == ""):
logger.error("resampling_by_scale:Parameters of the abnormal!")
return False
cols = dataset.RasterXSize # 列数
scale = ref_cols/cols # 参考图像的分辨率
# rows = dataset.RasterYSize # 行数
# cols = int(cols * scale) # 计算新的行列数
# rows = int(rows * scale)
cols = ref_cols
rows = ref_rows
geotrans = list(dataset.GetGeoTransform())
geotrans[1] = geotrans[1] / scale # 像元宽度变为原来的scale倍
geotrans[5] = geotrans[5] / scale # 像元高度变为原来的scale倍
if os.path.exists(target_file) and os.path.isfile(target_file): # 如果已存在同名影像
os.remove(target_file) # 则删除之
if not os.path.exists(os.path.split(target_file)[0]):
os.makedirs(os.path.split(target_file)[0])
band1 = dataset.GetRasterBand(1)
data_type = band1.DataType
target = dataset.GetDriver().Create(target_file, xsize=cols, ysize=rows, bands=band_count,
eType=data_type)
target.SetProjection(dataset.GetProjection()) # 设置投影坐标
target.SetGeoTransform(geotrans) # 设置地理变换参数
total = band_count + 1
for index in range(1, total):
# 读取波段数据
data = dataset.GetRasterBand(index).ReadAsArray(buf_xsize=cols, buf_ysize=rows)
out_band = target.GetRasterBand(index)
no_data_value = dataset.GetRasterBand(index).GetNoDataValue() # 获取没有数据的点
if not (no_data_value is None):
out_band.SetNoDataValue(no_data_value)
out_band.WriteArray(data) # 写入数据到新影像中
out_band.FlushCache()
out_band.ComputeBandStats(False) # 计算统计信息
del dataset
del target
return True
@staticmethod
def resampling_array(input_path, refer_img_path):
"""
按照缩放比例对影像重采样
:param input_path: GDAL地理数据路径
:param refer_img_path:参考影像
:return: True or False
"""
ref_dataset = gdal.Open(refer_img_path)
ref_cols = ref_dataset.RasterXSize # 列数
ref_rows = ref_dataset.RasterYSize # 行数
dataset = gdal.Open(input_path) # 判断数据是否存在
if dataset is None:
logger.error('resampling_by_scale:dataset is None!')
return False
band_count = dataset.RasterCount # 波段数
if band_count == 0:
logger.error("resampling_by_scale:Parameters of the abnormal!")
return False
cols = dataset.RasterXSize # 列数
scale = ref_cols/cols # 参考图像的分辨率
cols = ref_cols
rows = ref_rows
geotrans = list(dataset.GetGeoTransform())
geotrans[1] = geotrans[1] / scale # 像元宽度变为原来的scale倍
geotrans[5] = geotrans[5] / scale # 像元高度变为原来的scale倍
total = band_count
data = np.zeros((rows, cols, band_count), dtype=float)
for index in range(0, total):
array = dataset.GetRasterBand(index+1).ReadAsArray(buf_xsize=cols, buf_ysize=rows).astype(float)
data[:, :, index] = array
del dataset
return data