microproduct-l-sar/soilSalinity-Train_predict/SoilSalinityMain.py

# -*- coding: UTF-8 -*-
"""
@Project microproduct
@File SoilSalinityMain.py
@Author SHJ
@Contact土壤盐碱度算法主函数
@Date 2021/9/6
@Version 1.0.0
修改历史:
[修改序列] [修改日期] [修改者] [修改内容]
1 2022-6-27 石海军 1.增加配置文件config.ini; 2.内部处理使用地理坐标系(4326)
"""
import logging
import shutil
from tool.algorithm.algtools.MetaDataHandler import Calibration
from tool.algorithm.algtools.PreProcess import PreProcess as pp
from tool.algorithm.image.ImageHandle import ImageHandler
from tool.algorithm.polsarpro.pspLeeRefinedFilterT3 import LeeRefinedFilterT3
from tool.algorithm.xml.AlgXmlHandle import ManageAlgXML, CheckSource, InitPara
from tool.algorithm.algtools.logHandler import LogHandler
from tool.algorithm.algtools.ROIAlg import ROIAlg as roi
from tool.algorithm.block.blockprocess import BlockProcess
from tool.algorithm.xml.AnalysisXml import xml_extend
from tool.algorithm.xml.CreateMetaDict import CreateMetaDict, CreateProductXml
from tool.file.fileHandle import fileHandle
# from AHVToPolsarpro import AHVToPolsarpro
from tool.algorithm.polsarpro.AHVToPolsarpro import AHVToPolsarpro
from pspHAAlphaDecomposition import PspHAAlphaDecomposition
import scipy.spatial.transform
import scipy.spatial.transform._rotation_groups # imported explicitly to avoid packaging errors
import scipy.special.cython_special # imported explicitly to avoid packaging errors
from sklearn.cross_decomposition import PLSRegression
from tool.algorithm.xml.CreatMetafile import CreateMetafile
from SoilSalinityXmlInfo import CreateDict, CreateStadardXmlFile
import os
import datetime
import numpy as np
from PIL import Image
import sys
from tool.config.ConfigeHandle import Config as cf
from tool.csv.csvHandle import csvHandle
from tool.algorithm.transforml1a.transHandle import TransImgL1A
from tool.algorithm.ml.machineLearning import MachineLeaning as ml
import multiprocessing
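# Module-level setup: the product value range, pixel spacing, output suffix, product level,
# debug switch and exe name are read from config.ini via Config; PROJ_LIB is pointed at the
# program directory (presumably so the packaged tool can locate its PROJ data).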
csvh = csvHandle()
soil_salinity_value_min = float(cf.get('product_value_min'))
soil_salinity_value_max = float(cf.get('product_value_max'))
pixelspace = float(cf.get('pixelspace'))
tar = r'-' + cf.get('tar')
productLevel = cf.get('productLevel')
DEBUG = cf.get('debug') == 'True'
EXE_NAME = cf.get('exe_name')
# env_str = os.path.split(os.path.realpath(__file__))[0]
env_str = os.path.dirname(os.path.abspath(sys.argv[0]))
os.environ['PROJ_LIB'] = env_str
LogHandler.init_log_handler('run_log\\'+EXE_NAME)
logger = logging.getLogger("mylog")
file = fileHandle(DEBUG)
class SalinityMain:
"""
Main workflow of the soil salinity algorithm.
"""
def __init__(self, alg_xml_path):
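"""Parse the algorithm XML and initialize containers for workspace paths and parameters."""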
self.alg_xml_path = alg_xml_path
self.imageHandler = ImageHandler()
self.__alg_xml_handler = ManageAlgXML(alg_xml_path)
self.__check_handler = CheckSource(self.__alg_xml_handler)
self.__workspace_path, self.__out_para = None, None
self.__input_paras, self.__output_paras, self.__processing_paras, self.__preprocessed_paras = {}, {}, {}, {}
self.__feature_name_list = []
# reference image path and projection
self.__ref_img_path, self.__proj = '', ''
# width (columns) and height (rows)
self.__cols, self.__rows = 0, 0
# image geotransform (affine transform parameters)
self.__geo = [0, 0, 0, 0, 0, 0]
def check_source(self):
"""
检查算法相关的配置文件,图像,辅助文件是否齐全
"""
env_str = os.getcwd()
logger.info("sysdir: %s", env_str)
self.__check_handler.check_alg_xml()
self.__check_handler.check_run_env()
# check whether the scene image is fully polarimetric
self.__input_paras = self.__alg_xml_handler.get_input_paras()
if self.__check_handler.check_input_paras(self.__input_paras) is False:
return False
self.__workspace_path = self.__alg_xml_handler.get_workspace_path()
self.__create_work_space()
self.__processing_paras = InitPara.init_processing_paras(self.__input_paras)
self.__processing_paras.update(self.get_tar_gz_inf(self.__processing_paras["sar_path0"]))
SrcImageName = os.path.split(self.__input_paras["AHV"]['ParaValue'])[1].split('.tar.gz')[0]
result_name = SrcImageName + tar + ".tar.gz"
self.__out_para = os.path.join(self.__workspace_path, EXE_NAME, 'Output', result_name)
self.__alg_xml_handler.write_out_para("SoilSalinityProduct", self.__out_para)  # write the output parameter
logger.info('check_source success!')
logger.info('progress bar: 10%')
return True
def get_tar_gz_inf(self, tar_gz_path):
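"""
Extract the input .tar.gz SAR product and collect its parameters:
metadata file paths, per-polarization tif paths, and the orthorectification
parameter file (orth_para.txt).
:param tar_gz_path: path of the .tar.gz archive
:return: parameter dictionary
"""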
para_dic = {}
name = os.path.split(tar_gz_path)[1].replace('.tar.gz', '')  # note: rstrip('.tar.gz') would strip a character set, not the suffix
file_dir = os.path.join(self.__workspace_preprocessing_path, name + '\\')
file.de_targz(tar_gz_path, file_dir)
# metadata file dict
# para_dic.update(InitPara.get_meta_dic(InitPara.get_meta_paths(file_dir, name), name))
para_dic.update(InitPara.get_meta_dic_new(InitPara.get_meta_paths(file_dir, name), name))
# dict of tif paths keyed by polarization
para_dic.update(InitPara.get_polarization_mode(InitPara.get_tif_paths(file_dir, name)))
parameter_path = os.path.join(file_dir, "orth_para.txt")
para_dic.update({"paraMeter": parameter_path})
return para_dic
def __create_work_space(self):
"""
删除原有工作区文件夹,创建新工作区文件夹
"""
self.__workspace_preprocessing_path = self.__workspace_path + EXE_NAME +'\\Temporary\\preprocessing\\'
self.__workspace_preprocessed_path = self.__workspace_path + EXE_NAME + '\\Temporary\\preprocessed\\'
self.__workspace_processing_path = self.__workspace_path + EXE_NAME + '\\Temporary\\processing\\'
self.__workspace_block_tif_path = self.__workspace_path + EXE_NAME + '\\Temporary\\blockTif\\'
self.__workspace_block_tif_processed_path = self.__workspace_path + EXE_NAME + '\\Temporary\\blockTifProcessed\\'
self.__product_dic = self.__workspace_processing_path + 'product\\'
path_list = [self.__workspace_preprocessing_path, self.__workspace_preprocessed_path,
self.__workspace_processing_path, self.__workspace_block_tif_path,
self.__workspace_block_tif_processed_path,self.__product_dic]
file.creat_dirs(path_list)
logger.info('create new workspace success!')
def del_temp_workspace(self):
"""
临时工作区
"""
if DEBUG is True:
return
path = self.__workspace_path + EXE_NAME + r"\Temporary"
if os.path.exists(path):
file.del_folder(path)
def preprocess_handle(self):
"""
预处理
"""
para_names_geo = ["Covering", "NDVI", 'sim_ori']
p = pp()
p.check_img_projection(self.__workspace_preprocessing_path, para_names_geo, self.__processing_paras)
# compute the ROI
scopes = ()
# scopes += (self.imageHandler.get_scope_ori_sim(self.__processing_paras['ori_sim']),)
scopes += (xml_extend(self.__processing_paras['META']).get_extend(),)
scopes += p.box2scope(self.__processing_paras['box'])
# compute the image footprints and their intersection
intersect_shp_path = self.__workspace_preprocessing_path + 'IntersectPolygon.shp'
scopes_roi = p.cal_intersect_shp(intersect_shp_path, para_names_geo, self.__processing_paras, scopes)
# crop the images: the SAR images and the other auxiliary images
cutted_img_paths = p.cut_imgs(self.__workspace_preprocessing_path, para_names_geo, self.__processing_paras, intersect_shp_path)
self.__preprocessed_paras.update(cutted_img_paths)
para_names_l1a = ["HH", "VV", "HV", "VH"]
self.l1a_width = ImageHandler.get_img_width(self.__processing_paras['HH'])
self.l1a_height = ImageHandler.get_img_height(self.__processing_paras['HH'])
self._tr = TransImgL1A(self.__processing_paras['sim_ori'], scopes_roi, self.l1a_height, self.l1a_width)  # crop in L1A image coordinates
for name in para_names_l1a:
out_path = os.path.join(self.__workspace_preprocessed_path, name + "_preprocessed.tif")
self._tr.cut_L1A(self.__processing_paras[name], out_path)
self.__preprocessed_paras.update({name: out_path})
logger.info('preprocess_handle success!')
logger.info('progress bar: 15%')
def resampleImgs(self, refer_img_path):
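"""Resample the NDVI and Covering rasters so they match the reference image."""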
ndvi_rampling_path = self.__workspace_processing_path + "ndvi.tif"
pp.resampling_by_scale(self.__preprocessed_paras["NDVI"], ndvi_rampling_path, refer_img_path)
self.__preprocessed_paras["NDVI"] = ndvi_rampling_path
cover_rampling_path = self.__workspace_processing_path + "cover.tif"
pp.resampling_by_scale(self.__preprocessed_paras["Covering"], cover_rampling_path, refer_img_path)
self.__preprocessed_paras["Covering"] = cover_rampling_path
def create_roi(self):
"""
Compute the ROI mask.
:return: path to the mask
"""
names = ['Covering', 'NDVI']
bare_land_mask_path = roi().roi_process(names, self.__workspace_processing_path + "/roi/", self.__processing_paras, self.__preprocessed_paras)
logger.info('create masks success!')
return bare_land_mask_path
def AHVToPolsarpro(self,out_dir):
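"""
Full-polarization processing chain: radiometric calibration of HH/HV/VH/VV,
conversion to the PolSARpro T3 matrix, refined Lee filtering, and H/A/alpha
decomposition whose feature rasters are written to out_dir.
"""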
atp = AHVToPolsarpro()
ahv_path = self.__workspace_preprocessed_path
t3_path = self.__workspace_processing_path+'psp_t3\\'
# atp.ahv_to_polsarpro_t3_soil(t3_path, ahv_path)
polarization = ['HH', 'HV', 'VH', 'VV']
calibration = Calibration.get_Calibration_coefficient(self.__processing_paras['Origin_META'], polarization)
tif_path = atp.calibration(calibration, in_ahv_dir=self.__workspace_preprocessed_path)
atp.ahv_to_polsarpro_t3_soil(t3_path, tif_path)
logger.info('ahv transform to polsarpro T3 matrix success!')
logger.info('progress bar: 20%')
# refined Lee filtering
leeFilter = LeeRefinedFilterT3()
lee_filter_path = os.path.join(self.__workspace_processing_path,
'lee_filter\\')
leeFilter.api_lee_refined_filter_T3('', t3_path, lee_filter_path, 0, 0, atp.rows(), atp.cols())
logger.info('Refined_lee process success!')
haa = PspHAAlphaDecomposition(normalization=True)
haa.api_creat_h_a_alpha_features(h_a_alpha_out_dir=out_dir,
h_a_alpha_decomposition_T3_path='h_a_alpha_decomposition_T3.exe' ,
h_a_alpha_eigenvalue_set_T3_path='h_a_alpha_eigenvalue_set_T3.exe' ,
h_a_alpha_eigenvector_set_T3_path='h_a_alpha_eigenvector_set_T3.exe',
polsarpro_in_dir=lee_filter_path)
def create_meta_file(self, product_path):
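"""Build the product metadata XML from the product.xml template and copy it to the Output folder."""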
xml_path = "./model_meta.xml"
tem_folder = self.__workspace_path + EXE_NAME + r"\Temporary" + "\\"
image_path = product_path
out_path1 = os.path.join(tem_folder, "trans_geo_projcs.tif")
out_path2 = os.path.join(tem_folder, "trans_projcs_geo.tif")
# par_dict = CreateDict(image_path, [1, 1, 1, 1], out_path1, out_path2).calu_nature(start)
# model_xml_path = os.path.join(tem_folder, "creat_standard.meta.xml")  # output xml path
# CreateStadardXmlFile(xml_path, self.alg_xml_path, par_dict, model_xml_path).create_standard_xml()
# package the product folder
SrcImagePath = self.__input_paras["AHV"]['ParaValue']
paths = SrcImagePath.split(';')
SrcImageName = os.path.split(paths[0])[1].split('.tar.gz')[0]
# if len(paths) >= 2:
# for i in range(1, len(paths)):
# SrcImageName = SrcImageName + ";" + os.path.split(paths[i])[1].split('.tar.gz')[0]
# meta_xml_path = self.__product_dic + EXE_NAME + "Product.meta.xml"
# CreateMetafile(self.__processing_paras['META'], self.alg_xml_path, model_xml_path, meta_xml_path).process(
# SrcImageName)
model_path = "./product.xml"
meta_xml_path = os.path.join(self.__product_dic, SrcImageName + tar + ".meta.xml")
para_dict = CreateMetaDict(image_path, self.__processing_paras['Origin_META'], self.__workspace_processing_path,
out_path1, out_path2).calu_nature()
para_dict.update({"imageinfo_ProductName": "土壤盐碱度"})
para_dict.update({"imageinfo_ProductIdentifier": "SoilSalinity"})
para_dict.update({"imageinfo_ProductLevel": productLevel})
para_dict.update({"ProductProductionInfo_BandSelection": "1,2"})
CreateProductXml(para_dict, model_path, meta_xml_path).create_standard_xml()
temp_folder = os.path.join(self.__workspace_path, EXE_NAME, 'Output')
out_xml = os.path.join(temp_folder, os.path.basename(meta_xml_path))
if os.path.exists(temp_folder) is False:
os.mkdir(temp_folder)
# CreateProductXml(para_dict, model_path, out_xml).create_standard_xml()
shutil.copy(meta_xml_path, out_xml)
def calInterpolation_bil_Wgs84_rc_sar_sigma(self, parameter_path, dem_rc, in_sar, out_sar):
'''
Call SIMOrthoProgram.exe in mode 11:
SIMOrthoProgram.exe 11 in_parameter_path in_rc_wgs84_path in_ori_sar_path out_orth_sar_path
(bilinear interpolation of the L1A SAR image onto the WGS84 geographic grid)
'''
exe_path = r".\baseTool\x64\Release\SIMOrthoProgram.exe"
exe_cmd = r"set PROJ_LIB=.\baseTool\x64\Release; & {0} {1} {2} {3} {4} {5}".format(exe_path, 11, parameter_path,
dem_rc, in_sar, out_sar)
print(exe_cmd)
print(os.system(exe_cmd))
print("==========================================================================")
def process_handle(self, start):
"""
Main processing function of the algorithm.
:return: True or False
"""
# read the measured samples and convert them from lon/lat to L1A image (row, col) coordinates
# measured_data_img = self._tr.tran_lonlats_to_L1A_rowcols(csvh.readcsv(self.__processing_paras['MeasuredData']), self.__processing_paras['ori_sim'])
measured_data_img = self._tr.tran_lonlats_to_L1A_rowcols(csvh.readcsv(self.__processing_paras['MeasuredData']),
self.__preprocessed_paras['sim_ori'], self.l1a_height,
self.l1a_width)
# if len(measured_data_img) < 4:
#     raise Exception('Not enough measured data for model training')
# polarimetric decomposition to obtain the T3 matrix
out_dir = self.__workspace_processing_path+'psp_haalpha\\'
self.AHVToPolsarpro(out_dir)
# split into blocks
bp = BlockProcess()
rows = self.imageHandler.get_img_height(self.__preprocessed_paras['HH'])
cols = self.imageHandler.get_img_width(self.__preprocessed_paras['HH'])
block_size = bp.get_block_size(rows, cols)
bp.cut(out_dir, self.__workspace_block_tif_path, ['tif', 'tiff'], 'tif', block_size)
img_dir, img_name = bp.get_file_names(self.__workspace_block_tif_path, ['tif'])
dir_dict = bp.get_same_img(img_dir, img_name)
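# dir_dict presumably maps each feature name to the list of its block tifs; stacking the
# n-th block of every feature below yields the multi-band feature matrix for that block.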
logger.info('blocking tifs success!')
# merge the 54 feature matrices into a single 54-band matrix (one band per feature)
for key in dir_dict:
key_name = key
block_num = len(dir_dict[key])
for n in range(block_num):
name = os.path.basename(dir_dict[key_name][n])
suffix = '_' + name.split('_')[-4] + "_" + name.split('_')[-3] + "_" + name.split('_')[-2] + "_" + \
name.split('_')[-1]
features_path = self.__workspace_block_tif_processed_path + "features\\features" + suffix
features_array = np.zeros((len(dir_dict), block_size, block_size), dtype='float32')
for m, value in zip(range(len(dir_dict)), dir_dict.values()):
features_array[m, :, :] = self.imageHandler.get_band_array(value[n], 1)
# replace NaN/Inf values with 0
features_array[np.isnan(features_array)] = 0.0
features_array[np.isinf(features_array)] = 0.0
self.imageHandler.write_img(features_path, "", [0, 0, 1, 0, 0, 1], features_array)
logger.info('create features matrix success!')
# for n in range(block_num):
# name = os.path.basename(dir_dict[key_name][n])
# suffix = '_' + name.split('_')[-4] + "_" + name.split('_')[-3] + "_" + name.split('_')[-2] + "_" + \
# name.split('_')[-1]
# features_path = self.__workspace_block_tif_processed_path + "features\\features" + suffix
# row = self.imageHandler.get_img_height(dir_dict[key_name][n])
# col = self.imageHandler.get_img_width(dir_dict[key_name][n])
# features_array = np.zeros((len(dir_dict), row, col), dtype='float32')
# for m, value in zip(range(len(dir_dict)), dir_dict.values()):
# features_array[m, :, :] = self.imageHandler.get_band_array(value[n], 1)
# # 异常值转为0
# features_array[np.isnan(features_array)] = 0.0
# features_array[np.isinf(features_array)] = 0.0
# self.imageHandler.write_img(features_path, "", [0, 0, 1, 0, 0, 1], features_array)
# logger.info('create features matrix success!')
# build the training set
block_features_dir, block_features_name = bp.get_file_names(self.__workspace_block_tif_processed_path + 'features\\', ['tif'])
X_train, Y_train = ml.gene_train_data(block_features_dir, rows, cols, block_size, measured_data_img)
optimal_feature = ml.sel_optimal_feature_set(X_train, Y_train, threshold=0.01)
optimal_feature = ml.remove_correlation_feature(X_train, optimal_feature, threshold=0.85)
X_train = X_train[:, optimal_feature]
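# optimal_feature holds the indices of the selected feature bands; the same index set is
# applied to every test block before prediction so train and test features stay aligned.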
feature_name_list = list(dir_dict.keys())
logger.info('feature_list:%s', feature_name_list)
logger.info('train_feature:%s', np.array(feature_name_list)[optimal_feature])
logger.info('generating training set success!')
logger.info('progress bar: 80%')
# train the model
logger.info('PLS is training!')
# NOTE: the original code also constructed a second, unused
# PLSRegression(copy=True, max_iter=1000, n_components=5, scale=True, tol=1e-06);
# pass those arguments to the constructor here if they were the intended settings.
pls = PLSRegression()
pls.fit(X_train, Y_train)
logger.info('train PLS success!')
# predict block by block
for path, name, n in zip(block_features_dir, block_features_name, range(len(block_features_dir))):
features_array = self.imageHandler.get_data(path)
X_test = np.reshape(features_array, (features_array.shape[0], features_array[0].size)).T
X_test = X_test[:, optimal_feature]
Y_test = pls.predict(X_test)
Y_test[Y_test < soil_salinity_value_min] = soil_salinity_value_min
Y_test[Y_test > soil_salinity_value_max] = soil_salinity_value_max
salinity_img = Y_test.reshape(features_array.shape[1], features_array.shape[2])
out_image = Image.fromarray(salinity_img)
suffix = '_' + name.split('_')[-4] + "_" + name.split('_')[-3] + "_" + name.split('_')[-2] + "_" + \
name.split('_')[-1]
out_path = self.__workspace_block_tif_processed_path + 'salinity\\' + 'salinity' + suffix
if not os.path.exists(self.__workspace_block_tif_processed_path + 'salinity\\'):
os.makedirs(self.__workspace_block_tif_processed_path + 'salinity\\')
out_image.save(out_path)
# logger.info('total:%s,block:%s test data success!', len(block_features_dir), n)
logger.info('test data success!')
# merge the predicted blocks back into a full image
data_dir = self.__workspace_block_tif_processed_path + 'salinity\\'
out_path = self.__workspace_processing_path[0:-1]
bp.combine(data_dir, cols, rows, out_path, file_type=['tif'], datetype='float32')
# convert from L1A image coordinates to geographic coordinates
salinity_path = self.__workspace_processing_path + "salinity.tif"
SrcImageName = os.path.split(self.__input_paras["AHV"]['ParaValue'])[1].split('.tar.gz')[0] + tar + '.tif'
salinity_geo_path = os.path.join(self.__workspace_processing_path, SrcImageName)
self.calInterpolation_bil_Wgs84_rc_sar_sigma(self.__processing_paras['paraMeter'], self.__preprocessed_paras['sim_ori'], salinity_path, salinity_geo_path)
# self.inter_Range2Geo(self.__preprocessed_paras['ori_sim'], salinity_path, salinity_geo_path, pixelspace)
# self._tr.l1a_2_geo(self.__preprocessed_paras['ori_sim'], salinity_path, salinity_geo_path)
self.resampleImgs(salinity_geo_path)
# apply the ROI mask to generate the product
product_path = os.path.join(self.__product_dic, SrcImageName)
roi.cal_roi(product_path, salinity_geo_path, self.create_roi(), background_value=np.nan)
# generate the quick-view image
self.imageHandler.write_quick_view(product_path)
self.create_meta_file(product_path)
file.make_targz(self.__out_para, self.__product_dic)
logger.info('process_handle success!')
logger.info('progress bar: 100%')
if __name__ == '__main__':
multiprocessing.freeze_support()
start = datetime.datetime.now()
try:
if len(sys.argv) < 2:
xml_path = EXE_NAME + '.xml'
else:
xml_path = sys.argv[1]
main_handler = SalinityMain(xml_path)
if main_handler.check_source() is False:
raise Exception('check_source() failed!')
if main_handler.preprocess_handle() is False:
raise Exception('preprocess_handle() failed!')
if main_handler.process_handle(start) is False:
raise Exception('process_handle() failed!')
logger.info('successful production of ' + EXE_NAME + ' products!')
except Exception:
logger.exception("run-time error!")
finally:
main_handler.del_temp_workspace()
pass
end = datetime.datetime.now()
msg = 'running use time: %s ' % (end - start)
logger.info(msg)