320 lines
12 KiB
Python
320 lines
12 KiB
Python
# -*- coding: UTF-8 -*-
|
||
"""
|
||
@Project :microproduct
|
||
@File :VegetationHeightPrePro.py
|
||
@Function :@Function: 坐标转换,坐标系转换,图像裁剪,重投影,重采样
|
||
@Author :LMM
|
||
@Date :2021/8/25 14:17
|
||
@Version :1.0.0
|
||
"""
|
||
|
||
from shapely.geometry import Polygon # 导入 gdal库要放到这一句的后面,不然会引起错误
|
||
import logging
|
||
# import shapely
|
||
from osgeo import gdal
|
||
from osgeo import osr
|
||
from osgeo import ogr
|
||
import os
|
||
import shapefile
|
||
from shapely.errors import TopologicalError
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
logger = logging.getLogger("mylog")
|
||
env_str = os.getcwd()
|
||
os.environ['PROJ_LIB'] = env_str
|
||
|
||
|
||
class PreProcess:
|
||
"""
|
||
预处理,所有的影像配准
|
||
"""
|
||
def __init__(self):
|
||
pass
|
||
|
||
@staticmethod
|
||
def lonlat2geo(lat, lon):
|
||
"""
|
||
WGS84转平面坐标
|
||
Param: lat 为WGS_1984的纬度
|
||
Param: lon 为WGS_1984的经度
|
||
输出转换后的坐标x,y
|
||
"""
|
||
|
||
dstsrs1 = osr.SpatialReference()
|
||
dstsrs1.ImportFromEPSG(32649)
|
||
# print("输出投影pro_经纬度:",dstSRS1)
|
||
|
||
dstsrs2 = osr.SpatialReference()
|
||
dstsrs2.ImportFromEPSG(4326)
|
||
# print("输出投影:",dstSRS2)
|
||
|
||
ct = osr.CoordinateTransformation(dstsrs2, dstsrs1)
|
||
coords = ct.TransformPoint(lat, lon)
|
||
# print("输出转换后的坐标x,y:",coords[:2])
|
||
return coords[:2]
|
||
#
|
||
# def transtif2mask(self, out_tif_path, in_tif_path, threshold):
|
||
# """
|
||
# :param out_tif_path:输出路径
|
||
# :param in_tif_path:输入的路径
|
||
# :param threshold:阈值
|
||
# """
|
||
# im_proj, im_geotrans, im_arr, im_scope = self.read_img(in_tif_path)
|
||
# im_arr_mask = (im_arr < threshold).astype(int)
|
||
# self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)
|
||
|
||
@staticmethod
|
||
def trans_geogcs2projcs(out_path, in_path):
|
||
"""
|
||
:param out_path:wgs84投影坐标影像保存路径
|
||
:param in_path:地理坐标影像输入路径
|
||
"""
|
||
# 创建文件
|
||
if os.path.exists(os.path.split(out_path)[0]) is False:
|
||
os.makedirs(os.path.split(out_path)[0])
|
||
options = gdal.WarpOptions(format='GTiff', srcSRS='EPSG:4326', dstSRS='EPSG:32649')
|
||
gdal.Warp(out_path, in_path, options=options)
|
||
|
||
@staticmethod
|
||
def trans_projcs2geogcs(out_path, in_path):
|
||
"""
|
||
:param out_path:wgs84投影坐标影像保存路径
|
||
:param in_path:地理坐标影像输入路径
|
||
"""
|
||
# 创建文件
|
||
if os.path.exists(os.path.split(out_path)[0]) is False:
|
||
os.makedirs(os.path.split(out_path)[0])
|
||
options = gdal.WarpOptions(format='GTiff', srcSRS='EPSG:32649', dstSRS='EPSG:4326')
|
||
gdal.Warp(out_path, in_path, options=options)
|
||
|
||
@staticmethod
|
||
def intersect_polygon(scopes_tuple):
|
||
"""
|
||
功能说明:计算多边形相交的区域坐标;注意:多边形区域会转变成凸区域再求交
|
||
:param scopes_tuple: 输入多个区域坐标的tuple
|
||
:return: 多边形相交的区域坐标((x0,y0),(x1,y1),..., (xn,yn))
|
||
"""
|
||
if len(scopes_tuple) < 2:
|
||
logger.error('len(scopes_tuple) < 2')
|
||
return
|
||
try:
|
||
# python四边形对象,会自动计算四个点,最后四个点顺序为:左上 左下 右下 右上 左上
|
||
tmp = tuple(scopes_tuple[0])
|
||
poly_intersect = Polygon(tmp).convex_hull # 计算四点顺序
|
||
for i in range(len(scopes_tuple)-1):
|
||
polygon_next = Polygon(tuple(scopes_tuple[i+1])).convex_hull
|
||
if poly_intersect.intersects(polygon_next): # 相不相交
|
||
poly_intersect = poly_intersect.intersection(polygon_next) # 相交面积
|
||
else:
|
||
msg = 'Image:' + str(i) + 'range does not overlap!'
|
||
logger.error(msg)
|
||
return
|
||
return list(poly_intersect.boundary.coords)[:-1]
|
||
except shapely.geos.TopologicalError:
|
||
logger.error('shapely.geos.TopologicalError occurred!')
|
||
return
|
||
|
||
@staticmethod
|
||
def write_polygon_shp(out_shp_path, point_list):
|
||
"""
|
||
功能说明:创建闭环的矢量文件。
|
||
:param out_shp_path :矢量文件保存路径
|
||
:param point_list :装有闭环点的列表[[x0,y0],[x1,y1]...[xn,yn]]
|
||
:return: True or False
|
||
"""
|
||
# 为了支持中文路径,请添加下面这句代码
|
||
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
|
||
# 为了使属性表字段支持中文,请添加下面这句
|
||
gdal.SetConfigOption("SHAPE_ENCODING", "")
|
||
# 注册所有的驱动
|
||
ogr.RegisterAll()
|
||
|
||
# 创建数据,这里以创建ESRI的shp文件为例
|
||
str_driver_name = "ESRI Shapefile"
|
||
o_driver = ogr.GetDriverByName(str_driver_name)
|
||
if o_driver is None:
|
||
msg = 'driver('+str_driver_name+')is invalid value'
|
||
logger.error(msg)
|
||
return False
|
||
|
||
# 创建数据源
|
||
if os.path.exists(out_shp_path) and os.path.isfile(out_shp_path): # 如果已存在同名文件
|
||
os.remove(out_shp_path) # 则删除之
|
||
o_ds = o_driver.CreateDataSource(out_shp_path)
|
||
if o_ds is None:
|
||
msg = 'create file failed!' + out_shp_path
|
||
logger.error(msg)
|
||
return False
|
||
|
||
# 创建图层,创建一个多边形图层
|
||
srs = osr.SpatialReference()
|
||
srs.ImportFromEPSG(4326) # 计算ztd需要用到经纬度,所以此处采用经纬度坐标系
|
||
# srs.ImportFromEPSG(32649) # 投影坐标系,空间参考:WGS84
|
||
o_layer = o_ds.CreateLayer("TestPolygon", srs, ogr.wkbPolygon)
|
||
if o_layer is None:
|
||
msg = 'create coverage failed!'
|
||
logger.error(msg)
|
||
return False
|
||
|
||
# 下面创建属性表
|
||
# 先创建一个叫FieldID的整型属性
|
||
o_field_id = ogr.FieldDefn("FieldID", ogr.OFTInteger)
|
||
o_layer.CreateField(o_field_id, 1)
|
||
|
||
# 再创建一个叫FeatureName的字符型属性,字符长度为50
|
||
o_field_name = ogr.FieldDefn("FieldName", ogr.OFTString)
|
||
o_field_name.SetWidth(100)
|
||
o_layer.CreateField(o_field_name, 1)
|
||
|
||
o_defn = o_layer.GetLayerDefn()
|
||
|
||
# 创建矩形要素
|
||
o_feature_rectangle = ogr.Feature(o_defn)
|
||
o_feature_rectangle.SetField(0, 1)
|
||
o_feature_rectangle.SetField(1, "IntersectRegion")
|
||
|
||
# 创建环对象ring
|
||
ring = ogr.Geometry(ogr.wkbLinearRing)
|
||
|
||
for i in range(len(point_list)):
|
||
ring.AddPoint(point_list[i][0], point_list[i][1])
|
||
ring.CloseRings()
|
||
|
||
# 创建环对象polygon
|
||
geom_rect_polygon = ogr.Geometry(ogr.wkbPolygon)
|
||
geom_rect_polygon.AddGeometry(ring)
|
||
|
||
o_feature_rectangle.SetGeometry(geom_rect_polygon)
|
||
o_layer.CreateFeature(o_feature_rectangle)
|
||
|
||
o_ds.Destroy()
|
||
return True
|
||
|
||
@staticmethod
|
||
def cut_img(output_path, input_path, shp_path):
|
||
"""
|
||
:param output_path:剪切后的影像
|
||
:param input_path:待剪切的影像
|
||
:param shp_path:矢量数据
|
||
:return: True or False
|
||
"""
|
||
r = shapefile.Reader(shp_path)
|
||
box = r.bbox
|
||
|
||
input_dataset = gdal.Open(input_path)
|
||
|
||
gdal.Warp(output_path, input_dataset, format='GTiff', outputBounds=box, cutlineDSName=shp_path, dstNodata=-9999)
|
||
# cutlineWhere="FIELD = ‘whatever’",
|
||
# optionally you can filter your cutline (shapefile) based on attribute values
|
||
# select the no data value you like
|
||
# ds = None
|
||
# do other stuff with ds object, it is your cropped dataset. in this case we only close the dataset.
|
||
del input_dataset
|
||
return True
|
||
|
||
@staticmethod
|
||
def resampling_by_scale(input_path, target_file, refer_img_path):
|
||
"""
|
||
按照缩放比例对影像重采样
|
||
:param input_path: GDAL地理数据路径
|
||
:param target_file: 输出影像
|
||
:param refer_img_path:参考影像
|
||
:return: True or False
|
||
"""
|
||
ref_dataset = gdal.Open(refer_img_path)
|
||
ref_cols = ref_dataset.RasterXSize # 列数
|
||
ref_rows = ref_dataset.RasterYSize # 行数
|
||
|
||
dataset = gdal.Open(input_path) # 判断数据是否存在
|
||
if dataset is None:
|
||
logger.error('resampling_by_scale:dataset is None!')
|
||
return False
|
||
|
||
band_count = dataset.RasterCount # 波段数
|
||
if (band_count == 0) or (target_file == ""):
|
||
logger.error("resampling_by_scale:Parameters of the abnormal!")
|
||
return False
|
||
|
||
cols = dataset.RasterXSize # 列数
|
||
scale = ref_cols/cols # 参考图像的分辨率
|
||
|
||
# rows = dataset.RasterYSize # 行数
|
||
# cols = int(cols * scale) # 计算新的行列数
|
||
# rows = int(rows * scale)
|
||
cols = ref_cols
|
||
rows = ref_rows
|
||
|
||
geotrans = list(dataset.GetGeoTransform())
|
||
geotrans[1] = geotrans[1] / scale # 像元宽度变为原来的scale倍
|
||
geotrans[5] = geotrans[5] / scale # 像元高度变为原来的scale倍
|
||
|
||
if os.path.exists(target_file) and os.path.isfile(target_file): # 如果已存在同名影像
|
||
os.remove(target_file) # 则删除之
|
||
if not os.path.exists(os.path.split(target_file)[0]):
|
||
os.makedirs(os.path.split(target_file)[0])
|
||
|
||
band1 = dataset.GetRasterBand(1)
|
||
data_type = band1.DataType
|
||
target = dataset.GetDriver().Create(target_file, xsize=cols, ysize=rows, bands=band_count,
|
||
eType=data_type)
|
||
target.SetProjection(dataset.GetProjection()) # 设置投影坐标
|
||
target.SetGeoTransform(geotrans) # 设置地理变换参数
|
||
total = band_count + 1
|
||
for index in range(1, total):
|
||
# 读取波段数据
|
||
data = dataset.GetRasterBand(index).ReadAsArray(buf_xsize=cols, buf_ysize=rows)
|
||
out_band = target.GetRasterBand(index)
|
||
|
||
no_data_value = dataset.GetRasterBand(index).GetNoDataValue() # 获取没有数据的点
|
||
if not (no_data_value is None):
|
||
out_band.SetNoDataValue(no_data_value)
|
||
|
||
out_band.WriteArray(data) # 写入数据到新影像中
|
||
out_band.FlushCache()
|
||
out_band.ComputeBandStats(False) # 计算统计信息
|
||
|
||
del dataset
|
||
del target
|
||
return True
|
||
|
||
@staticmethod
|
||
def resampling_array(input_path, refer_img_path):
|
||
"""
|
||
按照缩放比例对影像重采样
|
||
:param input_path: GDAL地理数据路径
|
||
:param refer_img_path:参考影像
|
||
:return: True or False
|
||
"""
|
||
ref_dataset = gdal.Open(refer_img_path)
|
||
ref_cols = ref_dataset.RasterXSize # 列数
|
||
ref_rows = ref_dataset.RasterYSize # 行数
|
||
|
||
dataset = gdal.Open(input_path) # 判断数据是否存在
|
||
if dataset is None:
|
||
logger.error('resampling_by_scale:dataset is None!')
|
||
return False
|
||
|
||
band_count = dataset.RasterCount # 波段数
|
||
if band_count == 0:
|
||
logger.error("resampling_by_scale:Parameters of the abnormal!")
|
||
return False
|
||
|
||
cols = dataset.RasterXSize # 列数
|
||
scale = ref_cols/cols # 参考图像的分辨率
|
||
cols = ref_cols
|
||
rows = ref_rows
|
||
|
||
geotrans = list(dataset.GetGeoTransform())
|
||
geotrans[1] = geotrans[1] / scale # 像元宽度变为原来的scale倍
|
||
geotrans[5] = geotrans[5] / scale # 像元高度变为原来的scale倍
|
||
|
||
total = band_count
|
||
data = np.zeros((rows, cols, band_count), dtype=float)
|
||
for index in range(0, total):
|
||
array = dataset.GetRasterBand(index+1).ReadAsArray(buf_xsize=cols, buf_ysize=rows).astype(float)
|
||
data[:, :, index] = array
|
||
del dataset
|
||
|
||
return data
|