# microproduct/tool/algorithm/algtools/PreProcess.py
#
# 554 lines
# 22 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: UTF-8 -*-
"""
@Project microproduct
@File PreProcess.py
@Function Coordinate conversion, coordinate-system transformation, image clipping, reprojection and resampling
@Author LMM
@Date 2021/8/25 14:17
@Version 1.0.0
"""
from shapely.geometry import Polygon # 导入 gdal库要放到这一句的后面不然会引起错误
from osgeo import gdal
from osgeo import gdalconst
from osgeo import osr
from osgeo import ogr
import os
import cv2
import numpy as np
import shutil
import scipy.spatial.transform
import scipy.spatial.transform._rotation_groups # 用于解决打包错误
import scipy.special.cython_special # 用于解决打包错误
import scipy.spatial.transform._rotation_groups # 解决打包的问题
import shapefile
from shapely.errors import TopologicalError
from tool.algorithm.image.ImageHandle import ImageHandler
import logging
# Module-level logger shared by everything in this file; it is looked up by the fixed name "mylog".
logger = logging.getLogger("mylog")
#
# os.environ['PROJ_LIB'] = os.getcwd()
class PreProcess:
    """
    Raster pre-processing helpers: projection checks, footprint intersection,
    clipping against the intersection polygon, and resampling onto a reference grid.

    Most methods delegate raster I/O to ``ImageHandler`` (project module) and
    to GDAL/OGR; polygon intersection is done with shapely.
    """

    def __init__(self):
        # Instance used by cal_intersect_shp() to query image footprints.
        self._ImageHandler = ImageHandler()

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _ensure_parent_dir(file_path):
        """
        Create the parent directory of ``file_path`` if it has one and it is missing.

        A bare file name (no directory component) is left alone; calling
        ``os.makedirs('')`` for it would raise ``FileNotFoundError``.
        """
        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    # ------------------------------------------------------------------ #
    # Work-flow entry points
    # ------------------------------------------------------------------ #
    def cal_scopes(self, processing_paras):
        """
        Collect the regions of interest described by ``processing_paras``.

        :param processing_paras: dict of parameters; every key containing
            ``'ori_sim'`` is passed to ``ImageHandler.get_scope_ori_sim`` and the
            mandatory ``'box'`` entry is parsed with :meth:`box2scope`
        :return: tuple of scopes (may be empty)
        :raises KeyError: if ``processing_paras`` has no ``'box'`` entry
        """
        scopes = ()
        for key, value in processing_paras.items():
            if 'ori_sim' in key:
                scopes += (ImageHandler.get_scope_ori_sim(value),)
        # box2scope() already returns () for '' and 'empty', so no extra guard is needed.
        scopes += self.box2scope(processing_paras['box'])
        return scopes

    def cal_scopes_roi(self, processing_paras):
        """Return the intersection polygon of all scopes found by :meth:`cal_scopes`."""
        return self.intersect_polygon(self.cal_scopes(processing_paras))

    def cut_geoimg(self, workspace_preprocessing_path, para_names_geo, processing_paras):
        """
        Check projections, compute the common footprint and clip every named image to it.

        :param workspace_preprocessing_path: directory receiving the intersection
            shapefile, reprojected images and clipped images
        :param para_names_geo: names (keys of ``processing_paras``) of the images to process
        :param processing_paras: dict mapping names to image paths; also read by :meth:`cal_scopes`
        :return: ``(cutted_img_paths, scopes_roi)`` - dict name -> clipped image path,
            and the intersection polygon point list
        """
        self.check_img_projection(workspace_preprocessing_path, para_names_geo, processing_paras)
        # Regions of interest coming from the parameters themselves.
        scopes = self.cal_scopes(processing_paras)
        # Intersect the image footprints with those regions and store the result as a shapefile.
        intersect_shp_path = os.path.join(workspace_preprocessing_path, 'IntersectPolygon.shp')
        scopes_roi = self.cal_intersect_shp(intersect_shp_path, para_names_geo, processing_paras, scopes)
        # Clip every image with the intersection polygon.
        cutted_img_paths = self.cut_imgs(workspace_preprocessing_path, para_names_geo, processing_paras,
                                         intersect_shp_path)
        return cutted_img_paths, scopes_roi

    def preprocessing(self, para_names, ref_img_name, processing_paras, workspace_preprocessing_path,
                      workspace_preprocessed_path):
        """
        Full pipeline: projection check, footprint intersection, clipping and resampling.

        :param para_names: names (keys of ``processing_paras``) of the images to process
        :param ref_img_name: name of the image whose clipped version is the resampling reference
        :param processing_paras: dict mapping names to image paths; must contain ``'box'``
        :param workspace_preprocessing_path: directory for intermediate files
        :param workspace_preprocessed_path: directory for the resampled results
        :return: dict name -> resampled image path (``None`` when there is nothing to resample)
        :raises KeyError: if ``ref_img_name`` is missing from the clipped images
            (this includes the case where clipping failed and returned ``{}``)
        """
        # Read every image and make sure it carries a geographic coordinate system.
        self.check_img_projection(workspace_preprocessing_path, para_names, processing_paras)
        # Compute the image outlines and their common area.
        intersect_shp_path = os.path.join(workspace_preprocessing_path, 'IntersectPolygon.shp')
        self.cal_intersect_shp(intersect_shp_path, para_names, processing_paras,
                               self.box2scope(processing_paras['box']))
        logger.info('create intersect shp success!')
        # Clip all images with the intersection polygon.
        cutted_img_paths = self.cut_imgs(workspace_preprocessing_path, para_names, processing_paras,
                                         intersect_shp_path)
        logger.info('cut images success!')
        # Resample everything onto the grid of the reference image.
        preprocessed_paras = self.resampling_img(workspace_preprocessed_path, para_names, cutted_img_paths,
                                                 cutted_img_paths[ref_img_name])
        logger.info('preprocess_handle success!')
        return preprocessed_paras

    def get_ref_inf(self, ref_img_path):
        """
        Collect basic information about the reference image.

        :param ref_img_path: path of the reference image
        :return: ``(ref_img_path, cols, rows, proj, geo)`` as reported by ``ImageHandler``
        """
        cols = ImageHandler.get_img_width(ref_img_path)
        rows = ImageHandler.get_img_height(ref_img_path)
        proj = ImageHandler.get_projection(ref_img_path)
        geo = ImageHandler.get_geotransform(ref_img_path)
        return ref_img_path, cols, rows, proj, geo

    def check_img_projection(self, out_dir, para_names, processing_paras):
        """
        Check the coordinate system of every named image.

        Images whose projection string starts with ``PROJCS`` are warped to
        EPSG:4326 into ``out_dir`` and ``processing_paras[name]`` is updated
        in place to point at the new file.

        :param out_dir: directory receiving the reprojected images
        :param para_names: names of the parameters to check
        :param processing_paras: dict mapping names to image paths (mutated)
        :return: ``False`` when ``para_names`` is empty, otherwise ``None``
        :raises Exception: if an image has an empty projection string
        """
        if not para_names:
            return False
        for name in para_names:
            proj = ImageHandler.get_projection(processing_paras[name])
            # Text before the first '[' of the projection string, e.g. PROJCS / GEOGCS.
            keyword = proj.split("[", 2)[0]
            if keyword == "PROJCS":
                # Projected coordinate system -> geographic coordinate system.
                file_name = os.path.split(processing_paras[name])[1]
                out_para = os.path.join(out_dir, file_name.split(".", 1)[0] + "_EPSG4326.tif")
                self.trans_epsg4326(out_para, processing_paras[name])
                processing_paras[name] = out_para
            elif keyword.strip() == "":
                raise Exception('coordinate is missing!')

    def preprocessing_oh2004(self, para_names, processing_paras, workspace_preprocessing_path,
                             workspace_preprocessed_path):
        """
        Projection check, footprint intersection and clipping (no resampling step).

        :param para_names: names (keys of ``processing_paras``) of the images to process
        :param processing_paras: dict mapping names to image paths; must contain ``'box'``
        :param workspace_preprocessing_path: directory for the intersection shapefile and reprojected images
        :param workspace_preprocessed_path: directory receiving the clipped images
        :return: ``(cutted_img_paths, scopes)`` - dict name -> clipped image path,
            and the intersection polygon point list
        """
        self.check_img_projection(workspace_preprocessing_path, para_names, processing_paras)
        intersect_shp_path = os.path.join(workspace_preprocessing_path, 'IntersectPolygon.shp')
        scopes = self.cal_intersect_shp(intersect_shp_path, para_names, processing_paras,
                                        self.box2scope(processing_paras['box']))
        logger.info('create intersect shp success!')
        # Note: the clipped images go straight into the "preprocessed" workspace here.
        cutted_img_paths = self.cut_imgs(workspace_preprocessed_path, para_names, processing_paras,
                                         intersect_shp_path)
        logger.info('cut images success!')
        return cutted_img_paths, scopes

    # ------------------------------------------------------------------ #
    # Coordinate / projection conversion
    # ------------------------------------------------------------------ #
    @staticmethod
    def lonlat2geo(lat, lon):
        """
        Convert a WGS84 (EPSG:4326) coordinate to EPSG:32649 plane coordinates.

        :param lat: WGS84 latitude
        :param lon: WGS84 longitude
        :return: first two components of the transformed point
        """
        projected_srs = osr.SpatialReference()
        projected_srs.ImportFromEPSG(32649)
        geographic_srs = osr.SpatialReference()
        geographic_srs.ImportFromEPSG(4326)
        ct = osr.CoordinateTransformation(geographic_srs, projected_srs)
        # NOTE(review): values are passed as (lat, lon); the axis order expected for
        # EPSG:4326 depends on the GDAL version / axis-mapping strategy - confirm.
        coords = ct.TransformPoint(lat, lon)
        return coords[:2]

    @staticmethod
    def trans_geogcs2projcs(out_path, in_path, EPSG_src=4326, EPSG_dst=32649):
        """
        Warp a geographic-coordinate image into a projected coordinate system.

        :param out_path: path where the projected image is saved
        :param in_path: path of the input geographic-coordinate image
        :param EPSG_src: EPSG code of the source coordinate system (default 4326)
        :param EPSG_dst: EPSG code of the target coordinate system (default 32649)
        """
        PreProcess._ensure_parent_dir(out_path)
        options = gdal.WarpOptions(format='GTiff', srcSRS='EPSG:' + str(EPSG_src),
                                   dstSRS='EPSG:' + str(EPSG_dst))
        gdal.Warp(out_path, in_path, options=options)

    @staticmethod
    def trans_projcs2geogcs(out_path, in_path, EPSG_src=32649, EPSG_dst=4326):
        """
        Warp a projected-coordinate image into a geographic coordinate system.

        :param out_path: path where the geographic-coordinate image is saved
        :param in_path: path of the input projected-coordinate image
        :param EPSG_src: EPSG code of the source coordinate system (default 32649)
        :param EPSG_dst: EPSG code of the target coordinate system (default 4326)
        """
        PreProcess._ensure_parent_dir(out_path)
        options = gdal.WarpOptions(format='GTiff', srcSRS='EPSG:' + str(EPSG_src),
                                   dstSRS='EPSG:' + str(EPSG_dst))
        gdal.Warp(out_path, in_path, options=options)

    @staticmethod
    def trans_epsg4326(out_path, in_path):
        """
        Warp ``in_path`` to EPSG:4326 with bilinear resampling and write it to ``out_path``.

        :return: always ``True``
        """
        out_tile = gdal.Warp(out_path, in_path,
                             dstSRS='EPSG:4326',
                             resampleAlg=gdalconst.GRA_Bilinear)
        # Drop the reference so GDAL flushes and closes the output.
        out_tile = None
        return True

    # ------------------------------------------------------------------ #
    # Footprints and intersection
    # ------------------------------------------------------------------ #
    @staticmethod
    def box2scope(str_box):
        """
        Parse a ``';'``-separated box string into a one-element tuple of corner points.

        With the four parsed numbers ``b0;b1;b2;b3`` the corners are
        ``[b2, b1], [b3, b1], [b2, b0], [b3, b0]``.

        :param str_box: box string, ``''`` or ``'empty'``
        :return: ``()`` for ``''``/``'empty'`` or when the string does not hold
            exactly four numbers, otherwise ``([p0, p1, p2, p3],)``
        :raises ValueError: if a field cannot be converted to ``float``
        """
        if str_box == '' or str_box == 'empty':
            return ()
        box_list = [float(num) for num in str_box.split(';')]
        if len(box_list) != 4:
            return ()
        return ([[box_list[2], box_list[1]], [box_list[3], box_list[1]],
                 [box_list[2], box_list[0]], [box_list[3], box_list[0]]],)

    def cal_intersect_shp(self, shp_path, para_names, processing_paras, add_scope=()):
        """
        Intersect the footprints of the named images (plus optional extra scopes)
        and write the result as a polygon shapefile with EPSG:4326.

        :param shp_path: path of the shapefile to create
        :param para_names: names of the images whose footprints are intersected
        :param processing_paras: dict mapping names to image paths
        :param add_scope: tuple of additional scopes to include in the intersection
        :return: point list of the intersection polygon
        :raises Exception: if the scopes do not overlap or the shapefile cannot be written
        """
        scopes = ()
        if len(add_scope) != 0:
            scopes += add_scope
        for name in para_names:
            scopes += (self._ImageHandler.get_scope(processing_paras[name]),)
        for n, scope in enumerate(scopes):
            logger.info("scope%s:%s", n, scope)

        intersect_polygon = self.intersect_polygon(scopes)
        if intersect_polygon is None:
            logger.error('image range does not overlap!')
            raise Exception('create intersect shp fail!')
        logger.info("scope roi :%s", intersect_polygon)
        if self.write_polygon_shp(shp_path, intersect_polygon, 4326) is False:
            raise Exception('create intersect shp fail!')
        return intersect_polygon

    @staticmethod
    def intersect_polygon(scopes_tuple):
        """
        Compute the area shared by several polygons.

        Every input region is replaced by its convex hull before intersecting.

        :param scopes_tuple: sequence of regions, each a sequence of ``(x, y)`` points
        :return: list of vertices of the shared polygon (closing point removed),
            or ``None`` when the input is empty, the regions do not share an
            area, or shapely reports a topological error
        """
        if len(scopes_tuple) < 2:
            logger.error('len(scopes_tuple) < 2')
            # A single scene is still processed (its hull is returned);
            # only a completely empty input has nothing to intersect.
            if len(scopes_tuple) == 0:
                return None
        try:
            poly_intersect = Polygon(tuple(scopes_tuple[0])).convex_hull
            for index, scope in enumerate(scopes_tuple[1:], start=1):
                polygon_next = Polygon(tuple(scope)).convex_hull
                if not poly_intersect.intersects(polygon_next):
                    logger.error('Image:%s range does not overlap!', index)
                    return None
                poly_intersect = poly_intersect.intersection(polygon_next)
                # Polygons that merely touch intersect in a Point/LineString,
                # which has no ring of boundary coordinates to return.
                if poly_intersect.geom_type != 'Polygon':
                    logger.error('Image:%s range does not overlap!', index)
                    return None
            return list(poly_intersect.boundary.coords)[:-1]
        except TopologicalError:
            logger.error('shapely.geos.TopologicalError occurred!')
            return None

    @staticmethod
    def write_polygon_shp(out_shp_path, point_list, EPSG=32649):
        """
        Write a single closed polygon to an ESRI shapefile.

        :param out_shp_path: path of the shapefile to create
        :param point_list: polygon vertices ``[[x0, y0], [x1, y1], ..., [xn, yn]]``
        :param EPSG: EPSG code of the layer's spatial reference (default 32649)
        :return: ``True`` on success, ``False`` if the driver, data source or
            layer could not be obtained
        """
        # Per the original author's note: needed so that Chinese paths work.
        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
        # Per the original author's note: needed so that attribute fields support Chinese.
        gdal.SetConfigOption("SHAPE_ENCODING", "")
        ogr.RegisterAll()

        str_driver_name = "ESRI Shapefile"
        o_driver = ogr.GetDriverByName(str_driver_name)
        if o_driver is None:
            msg = 'driver('+str_driver_name+')is invalid value'
            logger.error(msg)
            return False

        # Remove an existing file of the same name before creating the data source.
        if os.path.exists(out_shp_path) and os.path.isfile(out_shp_path):
            os.remove(out_shp_path)
        o_ds = o_driver.CreateDataSource(out_shp_path)
        if o_ds is None:
            msg = 'create file failed!' + out_shp_path
            logger.error(msg)
            return False

        # Polygon layer using the requested spatial reference.
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(EPSG)
        o_layer = o_ds.CreateLayer("TestPolygon", srs, ogr.wkbPolygon)
        if o_layer is None:
            msg = 'create coverage failed!'
            logger.error(msg)
            # Release the data source that was already opened.
            o_ds.Destroy()
            return False

        # Attribute table: integer "FieldID" and string "FieldName" (width 100).
        o_field_id = ogr.FieldDefn("FieldID", ogr.OFTInteger)
        o_layer.CreateField(o_field_id, 1)
        o_field_name = ogr.FieldDefn("FieldName", ogr.OFTString)
        o_field_name.SetWidth(100)
        o_layer.CreateField(o_field_name, 1)

        # The single feature holding the intersection region.
        o_feature = ogr.Feature(o_layer.GetLayerDefn())
        o_feature.SetField(0, 1)
        o_feature.SetField(1, "IntersectRegion")

        # Build the closed ring from the vertex list.
        ring = ogr.Geometry(ogr.wkbLinearRing)
        for point in point_list:
            ring.AddPoint(point[0], point[1])
        ring.CloseRings()

        geom_polygon = ogr.Geometry(ogr.wkbPolygon)
        geom_polygon.AddGeometry(ring)
        o_feature.SetGeometry(geom_polygon)
        o_layer.CreateFeature(o_feature)
        o_ds.Destroy()
        return True

    # ------------------------------------------------------------------ #
    # Clipping
    # ------------------------------------------------------------------ #
    def cut_imgs(self, out_dir, para_names, processing_paras, shp_path):
        """
        Clip every named image with the polygon stored in ``shp_path``.

        Outputs are written to ``<out_dir>/<name>_cut.tif``.

        :param out_dir: directory receiving the clipped images
        :param para_names: names of the images to clip
        :param processing_paras: dict mapping names to image paths
        :param shp_path: shapefile used as cut line
        :return: dict name -> clipped image path; ``{}`` when ``para_names`` is
            empty or any clip fails
        """
        if len(para_names) == 0:
            return {}
        cutted_img_paths = {}
        try:
            for name in para_names:
                input_path = processing_paras[name]
                output_path = os.path.join(out_dir, name + '_cut.tif')
                if not self.cut_img(output_path, input_path, shp_path):
                    logger.error('cut_img failed!')
                    return {}
                cutted_img_paths[name] = output_path
                logger.info('cut %s success!', name)
        except Exception:
            # Keep the "empty dict on failure" contract, but record the traceback
            # and let KeyboardInterrupt / SystemExit propagate.
            logger.exception('cut_img failed!')
            return {}
        return cutted_img_paths

    def cut_imgs_VP(self, out_dir, para_names, processing_paras, shp_path, img_name):
        """
        Clip every named image with the polygon stored in ``shp_path``.

        Same as :meth:`cut_imgs`, except that the output of the parameter named
        ``'Covering'`` is prefixed with the 7th ``'_'``-separated field of ``img_name``.

        :param out_dir: directory receiving the clipped images
        :param para_names: names of the images to clip
        :param processing_paras: dict mapping names to image paths
        :param shp_path: shapefile used as cut line
        :param img_name: image name from which the ``'Covering'`` prefix is taken
        :return: dict name -> clipped image path; ``{}`` when ``para_names`` is
            empty or any clip fails
        """
        if len(para_names) == 0:
            return {}
        cutted_img_paths = {}
        try:
            for name in para_names:
                if name == 'Covering':
                    # Use a local so the caller-supplied img_name is never overwritten.
                    prefix = img_name.split('_')[6] + '_'
                    output_path = os.path.join(out_dir, prefix + name + '_cut.tif')
                else:
                    output_path = os.path.join(out_dir, name + '_cut.tif')
                input_path = processing_paras[name]
                if not self.cut_img(output_path, input_path, shp_path):
                    logger.error('cut_img failed!')
                    return {}
                cutted_img_paths[name] = output_path
                logger.info('cut %s success!', name)
        except Exception:
            logger.exception('cut_img failed!')
            return {}
        return cutted_img_paths

    @staticmethod
    def cut_img(output_path, input_path, shp_path):
        """
        Clip one image to the bounding box / cut line of a shapefile.

        The output is a GTiff whose nodata value is set to -9999.

        :param output_path: path of the clipped image
        :param input_path: path of the image to clip
        :param shp_path: shapefile providing the bounding box and the cut line
        :return: ``True`` on success, ``False`` if the input cannot be opened or
            the warp produces no dataset
        """
        # Read the bounding box and close the shapefile handles right away.
        with shapefile.Reader(shp_path) as reader:
            box = reader.bbox
        input_dataset = gdal.Open(input_path)
        if input_dataset is None:
            logger.error('cut_img:can not open %s', input_path)
            return False
        result = gdal.Warp(output_path, input_dataset, format='GTiff', outputBounds=box,
                           cutlineDSName=shp_path, dstNodata=-9999)
        succeeded = result is not None
        # Dropping the references closes the datasets and flushes the output.
        del result
        del input_dataset
        if not succeeded:
            logger.error('cut_img:warp failed for %s', input_path)
        return succeeded

    # ------------------------------------------------------------------ #
    # Resampling
    # ------------------------------------------------------------------ #
    def resampling_img(self, out_dir, para_names, img_paths, refer_img_path):
        """
        Resample every named image to the size of the reference image.

        Outputs are written to ``<out_dir>/<name>_preprocessed.tif``.

        :param out_dir: directory receiving the resampled images
        :param para_names: names of the images to resample
        :param img_paths: dict name -> path of the image to resample
        :param refer_img_path: path of the reference image
        :return: dict name -> resampled image path, or ``None`` when
            ``para_names`` or ``img_paths`` is empty
        """
        if len(para_names) == 0 or len(img_paths) == 0:
            return
        prepro_imgs_path = {}
        for name in para_names:
            output_para = os.path.join(out_dir, name + '_preprocessed.tif')
            self.resampling_by_scale(img_paths[name], output_para, refer_img_path)
            prepro_imgs_path[name] = output_para
            logger.info('resampling %s success!', name)
        return prepro_imgs_path

    @staticmethod
    def resampling_by_scale(input_path, target_file, refer_img_path):
        """
        Resample an image so that it has the same number of rows and columns
        as the reference image.

        If both images already have the same size the input is simply copied.
        Otherwise a new dataset is created with the input's driver, projection
        and data type, its pixel size is divided by the scale factors, and each
        band is read directly at the target size.

        :param input_path: path of the image to resample
        :param target_file: path of the output image
        :param refer_img_path: path of the reference image
        :return: ``True`` on success, ``False`` if a dataset cannot be opened or
            created, the input has no bands, or ``target_file`` is empty
        """
        ref_dataset = gdal.Open(refer_img_path)
        if ref_dataset is None:
            logger.error('resampling_by_scale:reference dataset is None!')
            return False
        ref_cols = ref_dataset.RasterXSize
        ref_rows = ref_dataset.RasterYSize
        del ref_dataset

        dataset = gdal.Open(input_path)
        if dataset is None:
            logger.error('resampling_by_scale:dataset is None!')
            return False
        cols = dataset.RasterXSize
        rows = dataset.RasterYSize

        if ref_cols == cols and ref_rows == rows:
            # Same grid size already: release the handle and copy the file as is.
            del dataset
            shutil.copyfile(input_path, target_file)
            return True

        band_count = dataset.RasterCount
        if (band_count == 0) or (target_file == ""):
            logger.error("resampling_by_scale:Parameters of the abnormal!")
            return False

        scale_x = ref_cols / cols
        scale_y = ref_rows / rows
        geotrans = list(dataset.GetGeoTransform())
        # More pixels over the same extent means a proportionally smaller pixel size.
        geotrans[1] = geotrans[1] / scale_x
        geotrans[5] = geotrans[5] / scale_y

        # Replace an existing output and make sure its directory exists.
        if os.path.exists(target_file) and os.path.isfile(target_file):
            os.remove(target_file)
        PreProcess._ensure_parent_dir(target_file)

        data_type = dataset.GetRasterBand(1).DataType
        target = dataset.GetDriver().Create(target_file, xsize=ref_cols, ysize=ref_rows, bands=band_count,
                                            eType=data_type)
        if target is None:
            logger.error('resampling_by_scale:create target failed!')
            return False
        target.SetProjection(dataset.GetProjection())
        target.SetGeoTransform(geotrans)

        for index in range(1, band_count + 1):
            src_band = dataset.GetRasterBand(index)
            # Reading with an explicit buffer size lets GDAL rescale the band.
            data = src_band.ReadAsArray(buf_xsize=ref_cols, buf_ysize=ref_rows)
            out_band = target.GetRasterBand(index)
            no_data_value = src_band.GetNoDataValue()
            if no_data_value is not None:
                out_band.SetNoDataValue(no_data_value)
            out_band.WriteArray(data)
            out_band.FlushCache()
            out_band.ComputeBandStats(False)

        del dataset
        del target
        return True

    # ------------------------------------------------------------------ #
    # Pixel-value processing
    # ------------------------------------------------------------------ #
    @staticmethod
    def cv_mean_filter(out_path, in_path, filter_size):
        """
        Apply a mean (box) filter to band 1 of an image.

        :param out_path: path of the filtered image
        :param in_path: path of the image to filter
        :param filter_size: side length of the square filter window
        :return: always ``True``
        """
        proj = ImageHandler.get_projection(in_path)
        geotrans = ImageHandler.get_geotransform(in_path)
        band = ImageHandler.get_band_array(in_path, 1)
        filtered = cv2.blur(band, (filter_size, filter_size))
        ImageHandler.write_img(out_path, proj, geotrans, filtered)
        return True

    @staticmethod
    def check_LocalIncidenceAngle(out_tif_path, in_tif_path):
        """
        Clean a local-incidence-angle image.

        -9999 is replaced by NaN; if the mean of the remaining values exceeds pi
        the data are treated as degrees and converted to radians; finally every
        value outside ``[0, pi/2)`` is set to NaN.

        :param out_tif_path: path of the processed image
        :param in_tif_path: path of the image to process
        """
        proj, geo, angle = ImageHandler.read_img(in_tif_path)
        angle = angle.astype(np.float32, order='C')
        angle[angle == -9999] = np.nan
        if np.nanmean(angle) > np.pi:
            # A mean above pi cannot be radians of an incidence angle: convert from degrees.
            angle = np.deg2rad(angle)
        angle[angle >= 0.5 * np.pi] = np.nan
        angle[angle < 0] = np.nan
        ImageHandler.write_img(out_tif_path, proj, geo, angle)