microproduct-l-sar/soilSalinity-Train_predict/SoilSalinityTrainMain.py


# -*- coding: UTF-8 -*-
"""
@Project microproduct
@File SoilSalinityTrainMain.py
@Author SHJ
@Description Main entry of the soil salinity train/predict algorithm
@Date 2021/9/6
@Version 1.0.0
Change history:
[Seq] [Date] [Author] [Change]
1 2022-6-27 石海军 1. Added the config file config.ini; 2. internal processing uses the geographic coordinate system (EPSG:4326)
"""
import glob
import logging
import shutil
import pickle
from tool.algorithm.algtools.MetaDataHandler import Calibration
from tool.algorithm.algtools.PreProcess import PreProcess as pp
from tool.algorithm.image.ImageHandle import ImageHandler
from tool.algorithm.polsarpro.pspLeeRefinedFilterT3 import LeeRefinedFilterT3
from tool.algorithm.xml.AlgXmlHandle import ManageAlgXML, CheckSource, InitPara
from tool.algorithm.algtools.logHandler import LogHandler
from tool.algorithm.xml.AnalysisXml import xml_extend
from tool.algorithm.xml.CreateMetaDict import CreateMetaDict, CreateProductXml
from tool.file.fileHandle import fileHandle
from tool.algorithm.polsarpro.AHVToPolsarpro import AHVToPolsarpro
from pspHAAlphaDecomposition import PspHAAlphaDecomposition
import scipy.spatial.transform
import scipy.spatial.transform._rotation_groups  # explicit import to work around a packaging (freeze) error
import scipy.special.cython_special  # explicit import to work around a packaging (freeze) error
from sklearn.cross_decomposition import PLSRegression
from scipy.stats import linregress
import os
import datetime
import numpy as np
import sys
import json
from tool.config.ConfigeHandle import Config as cf
from tool.csv.csvHandle import csvHandle
from tool.algorithm.transforml1a.transHandle import TransImgL1A
from tool.algorithm.ml.machineLearning import MachineLeaning as ml
import multiprocessing
pNum = cf.get('features')
optimal_feature = []
if pNum == 'all':
    for i in range(54):
        optimal_feature.append(i)
else:
    features = pNum.split(',')
    for n in features:
        optimal_feature.append(int(n))
print(optimal_feature)
csvh = csvHandle()
soil_salinity_value_min = float(cf.get('product_value_min'))
soil_salinity_value_max = float(cf.get('product_value_max'))
pixelspace = float(cf.get('pixelspace'))
tar = r'-' + cf.get('tar')
productLevel = cf.get('productLevel')
split_num = int(cf.get('split_num'))
if cf.get('debug') == 'True':
    DEBUG = True
else:
    DEBUG = False
EXE_NAME = cf.get('train_exe_name')
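# Illustrative config.ini values for the keys read above (key names are taken
# from the cf.get() calls in this module; the example values are assumptions,
# not shipped defaults):
#   features = all            # or a comma-separated index list, e.g. "0,3,17"
#   product_value_min = 0
#   product_value_max = 30
#   pixelspace = 5
#   tar = SSAA
#   productLevel = 5
#   split_num = 2
#   debug = False
#   train_exe_name = SoilSalinity-Train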
# env_str = os.path.split(os.path.realpath(__file__))[0]
env_str = os.path.dirname(os.path.abspath(sys.argv[0]))
os.environ['PROJ_LIB'] = env_str
LogHandler.init_log_handler('run_log\\'+EXE_NAME)
logger = logging.getLogger("mylog")
file = fileHandle(DEBUG)


class SalinityMain:
    """
    Main class of the algorithm.
    """

    def __init__(self, alg_xml_path):
        self.alg_xml_path = alg_xml_path
        self.imageHandler = ImageHandler()
        self.__alg_xml_handler = ManageAlgXML(alg_xml_path)
        self.__check_handler = CheckSource(self.__alg_xml_handler)
        self.__workspace_path, self.__out_para = None, None
        self.__input_paras, self.__output_paras, self.__processing_paras, self.__preprocessed_paras = {}, {}, {}, {}
        self.__feature_name_list = []
        # reference image path and its coordinate system
        self.__ref_img_path, self.__proj = '', ''
        # width/columns, height/rows
        self.__cols, self.__rows = 0, 0
        # image geotransform (GDAL six-element tuple)
        self.__geo = [0, 0, 0, 0, 0, 0]

    def check_source(self):
        """
        Check that the algorithm's configuration files, images, and auxiliary files are all present.
        """
        env_str = os.getcwd()
        logger.info("sysdir: %s", env_str)
        self.__check_handler.check_alg_xml()
        self.__check_handler.check_run_env()
        # check that the scene is fully polarimetric
        self.__input_paras = self.__alg_xml_handler.get_input_paras()
        if self.__check_handler.check_input_paras(self.__input_paras) is False:
            return False
        self.__workspace_path = self.__alg_xml_handler.get_workspace_path()
        self.__create_work_space()
        self.__processing_paras = InitPara.init_processing_paras(self.__input_paras, self.__workspace_preprocessing_path)
        self.__processing_paras.update(InitPara(DEBUG).get_mult_tar_gz_infs(self.__processing_paras, self.__workspace_preprocessing_path))
        SrcImageName = os.path.split(self.__input_paras["AHV"]['ParaValue'])[1].split('.tar.gz')[0]
        result_name = SrcImageName + tar + ".zip"
        self.__out_para = os.path.join(self.__workspace_path, EXE_NAME, 'Output', result_name)
        self.__alg_xml_handler.write_out_para("SoilSalinityProduct", self.__out_para)  # write the output parameter
        logger.info('check_source success!')
        logger.info('progress bar: 10%')
        return True

    def get_tar_gz_inf(self, tar_gz_path):
        para_dic = {}
        name = os.path.split(tar_gz_path)[1].split('.tar.gz')[0]
        file_dir = os.path.join(self.__workspace_preprocessing_path, name + '\\')
        file.de_targz(tar_gz_path, file_dir)
        # metadata dictionary
        # para_dic.update(InitPara.get_meta_dic(InitPara.get_meta_paths(file_dir, name), name))
        para_dic.update(InitPara.get_meta_dic_new(InitPara.get_meta_paths(file_dir, name), name))
        # tif path dictionary
        para_dic.update(InitPara.get_polarization_mode(InitPara.get_tif_paths(file_dir, name)))
        parameter_path = os.path.join(file_dir, "orth_para.txt")
        para_dic.update({"paraMeter": parameter_path})
        return para_dic
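
    # Keys expected in para_dic, inferred from how it is consumed in
    # preprocess_handle (illustrative, not an exhaustive contract):
    #   '<name>_META', '<name>_Origin_META'    metadata file paths
    #   '<name>_HH' / '_HV' / '_VH' / '_VV'    per-polarization tif paths
    #   'paraMeter'                            path of orth_para.txt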

    def __create_work_space(self):
        """
        Delete the existing workspace folders and create fresh ones.
        """
        self.__workspace_preprocessing_path = self.__workspace_path + EXE_NAME + '\\Temporary\\preprocessing\\'
        self.__workspace_preprocessed_path = self.__workspace_path + EXE_NAME + '\\Temporary\\preprocessed\\'
        self.__workspace_processing_path = self.__workspace_path + EXE_NAME + '\\Temporary\\processing\\'
        self.__workspace_block_tif_path = self.__workspace_path + EXE_NAME + '\\Temporary\\blockTif\\'
        self.__workspace_block_tif_processed_path = self.__workspace_path + EXE_NAME + '\\Temporary\\blockTifProcessed\\'
        self.__product_dic = self.__workspace_processing_path + 'product\\'
        path_list = [self.__workspace_preprocessing_path, self.__workspace_preprocessed_path,
                     self.__workspace_processing_path, self.__workspace_block_tif_path,
                     self.__workspace_block_tif_processed_path, self.__product_dic]
        file.creat_dirs(path_list)
        logger.info('create new workspace success!')

    def del_temp_workspace(self):
        """
        Delete the temporary workspace.
        """
        if DEBUG is True:
            return
        path = self.__workspace_path + EXE_NAME + r"\Temporary"
        if os.path.exists(path):
            file.del_folder(path)

    def preprocess_handle(self):
        """
        Preprocessing: crop each scene to the ROI, map the measured samples to
        L1A row/col coordinates, extract polarimetric features, and train the
        PLS model.
        """
        X_train = []
        Y_train = []
        for name in self.__processing_paras['name_list']:
            p = pp()
            # compute the ROI
            scopes = ()
            scopes += (xml_extend(self.__processing_paras[name + '_META']).get_extend(),)
            scopes += p.box2scope(self.__processing_paras['box'])
            scopes_roi = p.intersect_polygon(scopes)
            para_names_l1a = [name + "_HH", name + "_VV", name + "_HV", name + "_VH"]
            self.l1a_width = ImageHandler.get_img_width(self.__processing_paras[name + '_HH'])
            self.l1a_height = ImageHandler.get_img_height(self.__processing_paras[name + '_HH'])
            self._tr = TransImgL1A(self.__processing_paras[name + '_sim_ori'], scopes_roi, self.l1a_height, self.l1a_width)  # crop the image
            out_path = os.path.join(self.__workspace_preprocessed_path, name)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            for names in para_names_l1a:
                pred_path = os.path.join(out_path, names.split('-')[1] + '.tif')
                shutil.copy(self.__processing_paras[names], pred_path)
                self.__preprocessed_paras.update({names: pred_path})
            self.l1a_width = ImageHandler.get_img_width(self.__preprocessed_paras[name + '_HH'])
            self.l1a_height = ImageHandler.get_img_height(self.__preprocessed_paras[name + '_HH'])
            measured_data_img = self._tr.tran_lonlats_to_L1A_rowcols(
                csvh.readcsv(self.__processing_paras['MeasuredData']),
                self.__processing_paras[name + '_sim_ori'], self.l1a_height,
                self.l1a_width)
            # polarimetric decomposition to obtain the T3 matrix
            out_dir = os.path.join(self.__workspace_processing_path, name.split('-')[0])
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            feature_path = self.AHVToPolsarpro(out_dir, name)
            features_tif = list(glob.glob(os.path.join(feature_path, '*.tif')))
            features_array = np.zeros((len(features_tif), self.l1a_height, self.l1a_width), dtype='float32')
            # for i, tif in zip(range(len(features_tif)), features_tif):
            #     features_array[i, :, :] = self.imageHandler.get_band_array(tif, 1)
            # X_train_part, Y_train_part = ml.get_train_data(features_array, measured_data_img)
            feature_arr = np.zeros((len(measured_data_img), len(features_tif)), dtype=np.float32)
            y_part = []
            for i in range(len(features_tif)):
                featureArr = self.imageHandler.get_band_array(features_tif[i], 1)
                for j, data in enumerate(measured_data_img):
                    row = data[0]
                    col = data[1]
                    value = featureArr[row, col]
                    # skip NaN/inf pixels, leaving the zero placeholder in feature_arr
                    if not (np.isnan(value) or np.isinf(value)):
                        feature_arr[j, i] = value
                    # collect the measured value once per sample point
                    if i == 0:
                        y_part.append([data[2]])
            X_train_part = feature_arr
            Y_train_part = np.array(y_part)
            # optimal_features = ml.sel_optimal_feature_set(X_train_part, Y_train_part, threshold=0.01)
            # optimal_features = ml.remove_correlation_feature(X_train_part, optimal_feature, threshold=0.85)
            X_train_part = X_train_part[:, list(optimal_feature)]
            if len(X_train) == 0:
                X_train = X_train_part
                Y_train = Y_train_part
            else:
                X_train = np.vstack((X_train, X_train_part))
                Y_train = np.vstack((Y_train, Y_train_part))
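        # At this point X_train has shape (n_samples, len(optimal_feature)) and
        # Y_train has shape (n_samples, 1), with rows from all scenes stacked.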
        mid = int(X_train.shape[0] / split_num)  # split the samples into train/validation parts (1:1 when split_num is 2)
        X_train_in = X_train[:mid, :]
        Y_train_in = Y_train[:mid, :]
        X_train_out = X_train[mid:, :]
        Y_train_out = Y_train[mid:, :]

        # train the model
        logger.info('PLS is training!')
        pls = PLSRegression()
        pls.fit(X_train, Y_train)
        plsMap = {}
        x_data = json.dumps(X_train.tolist())
        y_data = json.dumps(Y_train.tolist())
        plsMap.update({'x_data': x_data})
        plsMap.update({'y_data': y_data})
        # PLSRegression(copy=True, max_iter=1000, n_components=5, scale=True, tol=1e-06)
        print("---- Train --------------------------------------------")
        Y_train_pred_in = pls.predict(X_train_in)
        print('Y_train_pred-----------------------------------------')
        print(Y_train_pred_in.reshape(-1))
        print('Y_train_pred-----------------------------------------')
        print('Y_train----------------------------------------------')
        print(Y_train_in.reshape(-1))
        print('Y_train----------------------------------------------')
        slope, intercept, r_value, p_value, std_err = linregress(Y_train_pred_in.reshape(-1), Y_train_in.reshape(-1))
        R2 = r_value ** 2
        print("train (scipy linregress) a", slope, "b", intercept, "r", r_value, "r-squared", R2)
        print("-----------------------------------------------------")
        print("---- Test --------------------------------------------")
        Y_train_pred_out = pls.predict(X_train_out)
        print('Y_train_pred-----------------------------------------')
        print(Y_train_pred_out.reshape(-1))
        print('Y_train_pred-----------------------------------------')
        print('Y_train----------------------------------------------')
        print(Y_train_out.reshape(-1))
        print('Y_train----------------------------------------------')
        slope, intercept, r_value, p_value, std_err = linregress(Y_train_pred_out.reshape(-1), Y_train_out.reshape(-1))
        R2 = r_value ** 2
        print("test (scipy linregress) a", slope, "b", intercept, "r", r_value, "r-squared", R2)
        print("-----------------------------------------------------")
        SrcImageName = os.path.split(self.__input_paras["AHV"]['ParaValue'])[1].split('.tar.gz')[0]
        model_path = os.path.join(self.__product_dic, SrcImageName + tar + '.pkl')
        with open(model_path, 'wb') as pkl:
            pickle.dump(pls, pkl)
        json_path = os.path.join(self.__product_dic, SrcImageName + tar + '.json')
        with open(json_path, 'w') as js:
            json.dump(plsMap, js)
        logger.info('train PLS success!')
        file.make_zip(self.__out_para, model_path)
        logger.info('process_handle success!')
        logger.info('progress bar: 100%')
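
    # A minimal sketch of how the artifacts written above could be reloaded for
    # prediction (hypothetical downstream code, not part of this module):
    #   with open(model_path, 'rb') as pkl:
    #       pls = pickle.load(pkl)
    #   y_pred = pls.predict(x_new[:, optimal_feature])  # x_new: (n, 54) feature matrix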

    def AHVToPolsarpro(self, out_dir, name):
        atp = AHVToPolsarpro()
        ahv_path = self.__workspace_preprocessed_path
        t3_path = os.path.join(out_dir, 'psp_t3')
        features_path = os.path.join(out_dir, 'psp_haalpha')
        # atp.ahv_to_polsarpro_t3_soil(t3_path, ahv_path)
        polarization = ['HH', 'HV', 'VH', 'VV']
        calibration = Calibration.get_Calibration_coefficient(self.__processing_paras[name + '_Origin_META'], polarization)
        tif_path = atp.calibration(calibration, in_ahv_dir=os.path.join(self.__workspace_preprocessed_path, name))
        atp.ahv_to_polsarpro_t3_soil(t3_path, tif_path)
        logger.info('ahv transform to polsarpro T3 matrix success!')
        logger.info('progress bar: 20%')
        # Lee filtering
        leeFilter = LeeRefinedFilterT3()
        lee_filter_path = os.path.join(out_dir, 'lee_filter\\')
        leeFilter.api_lee_refined_filter_T3('', t3_path, lee_filter_path, 0, 0, atp.rows(), atp.cols())
        logger.info('Refined_lee process success!')
        haa = PspHAAlphaDecomposition(normalization=True)
        haa.api_creat_h_a_alpha_features(h_a_alpha_out_dir=features_path,
                                         h_a_alpha_decomposition_T3_path='h_a_alpha_decomposition_T3.exe',
                                         h_a_alpha_eigenvalue_set_T3_path='h_a_alpha_eigenvalue_set_T3.exe',
                                         h_a_alpha_eigenvector_set_T3_path='h_a_alpha_eigenvector_set_T3.exe',
                                         polsarpro_in_dir=lee_filter_path)
        return features_path
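
    # Pipeline summary (descriptive, inferred from the calls above): radiometric
    # calibration -> PolSARpro T3 matrix -> refined Lee filter -> H/A/alpha
    # feature rasters written to features_path. preprocess_handle later globs
    # these *.tif files; range(54) at module level presumably matches the
    # number of H/A/alpha feature rasters produced.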

    def create_meta_file(self, product_path):
        xml_path = "./model_meta.xml"
        tem_folder = self.__workspace_path + EXE_NAME + r"\Temporary" + "\\"
        image_path = product_path
        out_path1 = os.path.join(tem_folder, "trans_geo_projcs.tif")
        out_path2 = os.path.join(tem_folder, "trans_projcs_geo.tif")
        # par_dict = CreateDict(image_path, [1, 1, 1, 1], out_path1, out_path2).calu_nature(start)
        # model_xml_path = os.path.join(tem_folder, "creat_standard.meta.xml")  # output xml path
        # CreateStadardXmlFile(xml_path, self.alg_xml_path, par_dict, model_xml_path).create_standard_xml()
        # package the product folder
        SrcImagePath = self.__input_paras["AHV"]['ParaValue']
        paths = SrcImagePath.split(';')
        SrcImageName = os.path.split(paths[0])[1].split('.tar.gz')[0]
        # if len(paths) >= 2:
        #     for i in range(1, len(paths)):
        #         SrcImageName = SrcImageName + ";" + os.path.split(paths[i])[1].split('.tar.gz')[0]
        # meta_xml_path = self.__product_dic + EXE_NAME + "Product.meta.xml"
        # CreateMetafile(self.__processing_paras['META'], self.alg_xml_path, model_xml_path, meta_xml_path).process(
        #     SrcImageName)
        model_path = "./product.xml"
        meta_xml_path = os.path.join(self.__product_dic, SrcImageName + tar + ".meta.xml")
        para_dict = CreateMetaDict(image_path, self.__processing_paras['Origin_META'], self.__workspace_processing_path,
                                   out_path1, out_path2).calu_nature()
        para_dict.update({"imageinfo_ProductName": "土壤盐碱度"})
        para_dict.update({"imageinfo_ProductIdentifier": "SoilSalinity"})
        para_dict.update({"imageinfo_ProductLevel": productLevel})
        para_dict.update({"ProductProductionInfo_BandSelection": "1,2"})
        CreateProductXml(para_dict, model_path, meta_xml_path).create_standard_xml()
        temp_folder = os.path.join(self.__workspace_path, EXE_NAME, 'Output')
        out_xml = os.path.join(temp_folder, os.path.basename(meta_xml_path))
        if os.path.exists(temp_folder) is False:
            os.mkdir(temp_folder)
        # CreateProductXml(para_dict, model_path, out_xml).create_standard_xml()
        shutil.copy(meta_xml_path, out_xml)

    def calInterpolation_bil_Wgs84_rc_sar_sigma(self, parameter_path, dem_rc, in_sar, out_sar):
        """
        Invoke SIMOrthoProgram.exe in mode 11:
        SIMOrthoProgram.exe 11 in_parameter_path in_rc_wgs84_path in_ori_sar_path out_orth_sar_path
        """
        exe_path = r".\baseTool\x64\Release\SIMOrthoProgram.exe"
        exe_cmd = r"set PROJ_LIB=.\baseTool\x64\Release; & {0} {1} {2} {3} {4} {5}".format(exe_path, 11, parameter_path,
                                                                                           dem_rc, in_sar, out_sar)
        print(exe_cmd)
        print(os.system(exe_cmd))
        print("==========================================================================")

    def process_handle(self, start):
        """
        Main processing function of the algorithm.
        :return: True or False
        """
        # read the measured values and convert them from lon/lat to image coordinates
        # measured_data_img = self._tr.tran_lonlats_to_L1A_rowcols(csvh.readcsv(self.__processing_paras['MeasuredData']), self.__processing_paras['ori_sim'])
        pass


if __name__ == '__main__':
    multiprocessing.freeze_support()
    start = datetime.datetime.now()
    try:
        if len(sys.argv) < 2:
            xml_path = EXE_NAME + '.xml'
        else:
            xml_path = sys.argv[1]
        main_handler = SalinityMain(xml_path)
        if main_handler.check_source() is False:
            raise Exception('check_source() failed!')
        if main_handler.preprocess_handle() is False:
            raise Exception('preprocess_handle() failed!')
        # if main_handler.process_handle(start) is False:
        #     raise Exception('process_handle() failed!')
        logger.info('successful production of ' + EXE_NAME + ' products!')
    except Exception:
        logger.exception("run-time error!")
    finally:
        main_handler.del_temp_workspace()
    end = datetime.datetime.now()
    msg = 'running use time: %s ' % (end - start)
    logger.info(msg)