更新叶面积指数模型训练功能
parent
db14c51ee7
commit
7834ac894a
Binary file not shown.
|
@ -0,0 +1,166 @@
|
|||
#
|
||||
# 模型计算的库
|
||||
#
|
||||
import cython
|
||||
cimport cython # 必须导入
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
from libc.math cimport pi
|
||||
from scipy.optimize import leastsq
|
||||
import random
|
||||
import logging
|
||||
logger = logging.getLogger("mylog")
|
||||
|
||||
|
||||
def WMCModel(param_arr,sample_lai,sample_soil,sample_inc,sample_sigma):
|
||||
""" WMC模型 增加 归一化植被指数
|
||||
|
||||
Args:
|
||||
param_arr (np.ndarray): 参数数组
|
||||
sample_lai (double): 叶面积指数
|
||||
sample_soil (double): 土壤含水量
|
||||
sample_inc (double): 入射角(弧度值)
|
||||
sample_sigma (double): 后向散射系数(线性值)
|
||||
|
||||
Returns:
|
||||
double: 方程值
|
||||
"""
|
||||
# 映射参数,方便修改模型
|
||||
A,B,C,D,M,N=param_arr # 在这里修改模型
|
||||
V_lai=sample_lai
|
||||
#V_lai=E*sample_lai+F
|
||||
exp_gamma=np.exp(-2*B*((V_lai*D+C))*(1/np.cos(sample_inc)))
|
||||
sigma_soil=M*sample_soil+N
|
||||
sigma_veg=A*((V_lai))*np.cos(sample_inc)
|
||||
f_veg=1
|
||||
result=sigma_veg*(1-exp_gamma)+sigma_soil*exp_gamma-sample_sigma
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
def train_WMCmodel(lai_water_inc_sigma_list,params_X0,train_err_image_path,draw_flag=True):
|
||||
""" 训练模型参数
|
||||
|
||||
Args:
|
||||
lai_waiter_inc_sigma_list (list): 训练模型使用的样本呢
|
||||
"""
|
||||
def f(X):
|
||||
eqs=[]
|
||||
for lai_water_inc_sigma_item in lai_water_inc_sigma_list:
|
||||
sample_lai=lai_water_inc_sigma_item[4]
|
||||
sample_sigma=lai_water_inc_sigma_item[5] # 5: csv_sigma, 8:tiff_sigma
|
||||
sample_soil=lai_water_inc_sigma_item[6]
|
||||
sample_inc=lai_water_inc_sigma_item[7]
|
||||
FVC=lai_water_inc_sigma_item[8]
|
||||
eqs.append(WMCModel(X,sample_lai,sample_soil,sample_inc,sample_sigma))
|
||||
return eqs
|
||||
|
||||
X0 = params_X0 # 初始值
|
||||
# logger.info(str(X0))
|
||||
h = leastsq(f, X0)
|
||||
# logger.info(h[0],h[1])
|
||||
err_f=f(h[0])
|
||||
x_arr=[lai_waiter_inc_sigma_item[4] for lai_waiter_inc_sigma_item in lai_water_inc_sigma_list]
|
||||
# 根据误差大小进行排序
|
||||
# logger.info("训练集:\n根据误差输出点序\n数量:{}\n点序\t误差值\t 样点信息".format(str(np.array(err_f).shape)))
|
||||
# for i in np.argsort(np.array(err_f)):
|
||||
# logger.info('{}\t{}\t{}'.format(i,err_f[i],str(lai_water_inc_sigma_list[i])))
|
||||
# logger.info("\n误差点序输出结束\n")
|
||||
|
||||
if draw_flag:
|
||||
# logger.info(err_f)
|
||||
# logger.info(np.where(np.abs(err_f)<10))
|
||||
from matplotlib import pyplot as plt
|
||||
plt.scatter(x_arr,err_f)
|
||||
plt.title("equation-err")
|
||||
plt.savefig(train_err_image_path,dpi=600)
|
||||
plt.show()
|
||||
|
||||
return h[0]
|
||||
|
||||
def test_WMCModel(lai_waiter_inc_sigma_list,param_arr,lai_X0,test_err_image_path,draw_flag=True):
|
||||
""" 测试模型训练结果
|
||||
|
||||
Args:
|
||||
lai_waiter_inc_sigma_list (list): 测试使用的样本集
|
||||
A (_type_): 参数A
|
||||
B (_type_): 参数B
|
||||
C (_type_): 参数C
|
||||
D (_type_): 参数D
|
||||
M (_type_): 参数M
|
||||
N (_type_): 参数N
|
||||
lai_X0 (_type_): 初始值
|
||||
|
||||
Returns:
|
||||
list: 误差列表 [sample_lai,err,predict]
|
||||
"""
|
||||
err=[]
|
||||
err_f=[]
|
||||
x_arr=[]
|
||||
err_lai=[]
|
||||
for lai_waiter_inc_sigma_item in lai_waiter_inc_sigma_list:
|
||||
sample_time,sample_code,sample_lon,sample_lat,sample_lai,csv_sigma,sample_soil,sample_inc,sample_sigma=lai_waiter_inc_sigma_item
|
||||
def f(X):
|
||||
lai=X[0]
|
||||
eqs=[WMCModel(param_arr,lai,sample_soil,sample_inc,csv_sigma)]
|
||||
return eqs
|
||||
X0=lai_X0
|
||||
h = leastsq(f, X0)
|
||||
temp_err=h[0]-sample_lai
|
||||
err_lai.append(temp_err[0]) # lai预测的插值
|
||||
err.append([sample_lai,temp_err[0],h[0][0],sample_code])
|
||||
err_f.append(f(h[0])[0]) # 方程差
|
||||
x_arr.append(sample_lai)
|
||||
|
||||
# 根据误差大小进行排序
|
||||
# logger.info("测试集:\n根据误差输出点序\n数量:{}\n点序\t误差值\t 方程差\t样点信息".format(str(np.array(err_lai).shape)))
|
||||
# for i in np.argsort(np.array(err_lai)):
|
||||
# logger.info('{}\t{}\t{}\t{}'.format(i,err_lai[i],err_f[i],str(lai_waiter_inc_sigma_list[i])))
|
||||
# logger.info("\n误差点序输出结束\n")
|
||||
|
||||
if draw_flag:
|
||||
from matplotlib import pyplot as plt
|
||||
plt.scatter(x_arr,err_lai)
|
||||
plt.title("equation-err")
|
||||
plt.savefig(test_err_image_path,dpi=600)
|
||||
plt.show()
|
||||
return err
|
||||
|
||||
def processs_WMCModel(param_arr,lai_X0,sigma,inc_angle,soil_water):
|
||||
|
||||
if(sigma<0 ):
|
||||
return np.nan
|
||||
def f(X):
|
||||
lai=X[0]
|
||||
eqs=[WMCModel(param_arr,lai,soil_water,inc_angle,sigma )]
|
||||
return eqs
|
||||
h = leastsq(f, [lai_X0])
|
||||
|
||||
return h[0][0]
|
||||
|
||||
# Cython 的扩展地址
|
||||
cpdef np.ndarray[double,ndim=2] process_tiff(np.ndarray[double,ndim=2] sigma_tiff,
|
||||
np.ndarray[double,ndim=2] inc_tiff,
|
||||
np.ndarray[double,ndim=2] soil_water_tiff,
|
||||
np.ndarray[double,ndim=1] param_arr,
|
||||
double lai_X0):
|
||||
|
||||
cdef np.ndarray[double,ndim=2] result=sigma_tiff
|
||||
cdef int param_arr_length=param_arr.shape[0]
|
||||
cdef int height=sigma_tiff.shape[0]
|
||||
cdef int width=sigma_tiff.shape[1]
|
||||
cdef int i=0
|
||||
cdef int j=0
|
||||
cdef double temp=0
|
||||
|
||||
while i<height:
|
||||
j=0
|
||||
while j<width:
|
||||
temp = processs_WMCModel(param_arr,lai_X0,sigma_tiff[i,j],inc_tiff[i,j],soil_water_tiff[i,j])
|
||||
temp=temp if temp<10 and temp>=0 else np.nan
|
||||
result[i,j]=temp
|
||||
j=j+1
|
||||
i=i+1
|
||||
return result
|
||||
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,45 @@
|
|||
from setuptools import setup
|
||||
from setuptools.extension import Extension
|
||||
from Cython.Distutils import build_ext
|
||||
from Cython.Build import cythonize
|
||||
import numpy
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
|
||||
class MyBuildExt(build_ext):
|
||||
def run(self):
|
||||
build_ext.run(self)
|
||||
|
||||
build_dir = Path(self.build_lib)
|
||||
root_dir = Path(__file__).parent
|
||||
target_dir = build_dir if not self.inplace else root_dir
|
||||
|
||||
self.copy_file(Path('./LAIProcess') / '__init__.py', root_dir, target_dir)
|
||||
#self.copy_file(Path('./pkg2') / '__init__.py', root_dir, target_dir)
|
||||
self.copy_file(Path('.') / '__init__.py', root_dir, target_dir)
|
||||
def copy_file(self, path, source_dir, destination_dir):
|
||||
if not (source_dir / path).exists():
|
||||
return
|
||||
shutil.copyfile(str(source_dir / path), str(destination_dir / path))
|
||||
|
||||
setup(
|
||||
name="MyModule",
|
||||
ext_modules=cythonize(
|
||||
[
|
||||
#Extension("pkg1.*", ["root/pkg1/*.py"]),
|
||||
Extension("pkg2.*", ["./LAIProcess.pyx"]),
|
||||
#Extension("1.*", ["root/*.py"])
|
||||
],
|
||||
build_dir="build",
|
||||
compiler_directives=dict(
|
||||
always_allow_keywords=True
|
||||
)),
|
||||
cmdclass=dict(
|
||||
build_ext=MyBuildExt
|
||||
),
|
||||
packages=[],
|
||||
include_dirs=[numpy.get_include()],
|
||||
)
|
||||
|
||||
# 指令: python setup.py build_ext --inplace
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,135 @@
|
|||
from xml.etree.ElementTree import ElementTree
|
||||
import os
|
||||
|
||||
|
||||
class DictXml:
|
||||
def __init__(self, xml_path):
|
||||
self.xml_path = xml_path
|
||||
self.__tree = ElementTree()
|
||||
self.__root = None
|
||||
self.init_xml()
|
||||
|
||||
def init_xml(self):
|
||||
self.__root = self.__tree.parse(self.xml_path)
|
||||
if self.__root is None:
|
||||
raise Exception("get root failed")
|
||||
|
||||
def get_extend(self):
|
||||
productInfo = self.__root.find("imageinfo")
|
||||
if productInfo is None:
|
||||
raise Exception("get imageInfo failed")
|
||||
|
||||
corner = productInfo.find("corner")
|
||||
if corner is None:
|
||||
raise Exception("get corner failed")
|
||||
|
||||
topLeft = corner.find("topLeft")
|
||||
if topLeft is None:
|
||||
raise Exception("get topLeft failed")
|
||||
|
||||
topRight = corner.find("topRight")
|
||||
if topRight is None:
|
||||
raise Exception("get topRight failed")
|
||||
|
||||
bottomLeft = corner.find("bottomLeft")
|
||||
if bottomLeft is None:
|
||||
raise Exception("get bottomLeft failed")
|
||||
|
||||
bottomRight = corner.find("bottomRight")
|
||||
if bottomRight is None:
|
||||
raise Exception("get bottomRight failed")
|
||||
|
||||
point_upleft = [float(topLeft.find("longitude").text), float(topLeft.find("latitude").text)]
|
||||
point_upright = [float(topRight.find("longitude").text), float(topRight.find("latitude").text)]
|
||||
point_downleft = [float(bottomLeft.find("longitude").text), float(bottomLeft.find("latitude").text)]
|
||||
point_downright = [float(bottomRight.find("longitude").text), float(bottomRight.find("latitude").text)]
|
||||
scopes = [point_upleft, point_upright, point_downleft, point_downright]
|
||||
|
||||
point_upleft_buf = [float(topLeft.find("longitude").text) - 0.5, float(topLeft.find("latitude").text) + 0.5]
|
||||
point_upright_buf = [float(topRight.find("longitude").text) + 0.5, float(topRight.find("latitude").text) + 0.5]
|
||||
point_downleft_buf = [float(bottomLeft.find("longitude").text) - 0.5,
|
||||
float(bottomLeft.find("latitude").text) - 0.5]
|
||||
point_downright_buf = [float(bottomRight.find("longitude").text) + 0.5,
|
||||
float(bottomRight.find("latitude").text) - 0.5]
|
||||
scopes_buf = [point_upleft_buf, point_upright_buf, point_downleft_buf, point_downright_buf]
|
||||
return scopes
|
||||
|
||||
|
||||
class xml_extend:
|
||||
def __init__(self, xml_path):
|
||||
self.xml_path = xml_path
|
||||
self.__tree = ElementTree()
|
||||
self.__root = None
|
||||
self.init_xml()
|
||||
|
||||
def init_xml(self):
|
||||
self.__root = self.__tree.parse(self.xml_path)
|
||||
if self.__root is None:
|
||||
raise Exception("get root failed")
|
||||
|
||||
def get_extend(self):
|
||||
ProductBasicInfo = self.__root.find("ProductBasicInfo")
|
||||
if ProductBasicInfo is None:
|
||||
raise Exception("get ProductBasicInfo failed")
|
||||
|
||||
SpatialCoverageInformation = ProductBasicInfo.find("SpatialCoverageInformation")
|
||||
if SpatialCoverageInformation is None:
|
||||
raise Exception("get SpatialCoverageInformation failed")
|
||||
|
||||
TopLeftLongitude = SpatialCoverageInformation.find("TopLeftLongitude")
|
||||
if TopLeftLongitude is None:
|
||||
raise Exception("get TopLeftLongitude failed")
|
||||
|
||||
TopLeftLatitude = SpatialCoverageInformation.find("TopLeftLatitude")
|
||||
if TopLeftLatitude is None:
|
||||
raise Exception("get TopLeftLatitude failed")
|
||||
|
||||
TopRightLongitude = SpatialCoverageInformation.find("TopRightLongitude")
|
||||
if TopRightLongitude is None:
|
||||
raise Exception("get TopRightLongitude failed")
|
||||
|
||||
TopRightLatitude = SpatialCoverageInformation.find("TopRightLatitude")
|
||||
if TopRightLatitude is None:
|
||||
raise Exception("get TopRightLatitude failed")
|
||||
|
||||
BottomRightLongitude = SpatialCoverageInformation.find("BottomRightLongitude")
|
||||
if BottomRightLongitude is None:
|
||||
raise Exception("get BottomRightLongitude failed")
|
||||
|
||||
BottomRightLatitude = SpatialCoverageInformation.find("BottomRightLatitude")
|
||||
if BottomRightLatitude is None:
|
||||
raise Exception("get BottomRightLatitude failed")
|
||||
|
||||
BottomLeftLongitude = SpatialCoverageInformation.find("BottomLeftLongitude")
|
||||
if BottomLeftLongitude is None:
|
||||
raise Exception("get BottomLeftLongitude failed")
|
||||
|
||||
BottomLeftLatitude = SpatialCoverageInformation.find("BottomLeftLatitude")
|
||||
if BottomLeftLatitude is None:
|
||||
raise Exception("get BottomLeftLatitude failed")
|
||||
|
||||
point_upleft = [float(TopLeftLongitude.text), float(TopLeftLatitude.text)]
|
||||
point_upright = [float(TopRightLongitude.text), float(TopRightLatitude.text)]
|
||||
point_downleft = [float(BottomLeftLongitude.text), float(BottomLeftLatitude.text)]
|
||||
point_downright = [float(BottomRightLongitude.text), float(BottomRightLatitude.text)]
|
||||
scopes = [point_upleft, point_upright, point_downleft, point_downright]
|
||||
|
||||
point_upleft_buf = [float(TopLeftLongitude.text) - 0.5, float(TopLeftLatitude.text) + 0.5]
|
||||
point_upright_buf = [float(TopRightLongitude.text) + 0.5, float(TopRightLatitude.text) + 0.5]
|
||||
point_downleft_buf = [float(BottomLeftLongitude.text) - 0.5, float(BottomLeftLatitude.text) - 0.5]
|
||||
point_downright_buf = [float(BottomRightLongitude.text) + 0.5, float(BottomRightLatitude.text) - 0.5]
|
||||
scopes_buf = [point_upleft_buf, point_upright_buf, point_downleft_buf, point_downright_buf]
|
||||
return scopes
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
xml_path = r'E:\MicroWorkspace\GF3A_nanjing\GF3_SAY_QPSI_011444_E118.9_N31.4_20181012_L1A_AHV_L10003515422\GF3_SAY_QPSI_011444_E118.9_N31.4_20181012_L1A_AHV_L10003515422.meta.xml'
|
||||
scopes, scopes_buf = DictXml(xml_path).get_extend()
|
||||
print(scopes)
|
||||
print(scopes_buf)
|
||||
# path = r'D:\BaiduNetdiskDownload\GZ\lon.rdr'
|
||||
# path2 = r'D:\BaiduNetdiskDownload\GZ\lat.rdr'
|
||||
# path3 = r'D:\BaiduNetdiskDownload\GZ\lon_lat.tif'
|
||||
# s = ImageHandler().band_merge(path, path2, path3)
|
||||
# print(s)
|
||||
# pass
|
Loading…
Reference in New Issue