microproduct/leafAreaIndex/sample_process.py

286 lines
10 KiB
Python
Raw Normal View History

2024-01-03 05:46:38 +00:00
#
# 样本处理的相关的库
#
2024-10-30 03:16:40 +00:00
import logging
2024-01-03 05:46:38 +00:00
from tool.algorithm.image.ImageHandle import ImageHandler
import math
import numpy as np
import random
import scipy
# 最小二乘求解非线性方程组
from scipy.optimize import leastsq,fsolve,root
from osgeo import gdal,gdalconst
import pandas as pds
from scipy import interpolate
from multiprocessing import pool
# 常量声明区域
imageHandler=ImageHandler()
2024-10-30 03:16:40 +00:00
logger = logging.getLogger("mylog")
2024-01-03 05:46:38 +00:00
# python 的函数类
def read_sample_csv(csv_path):
    """Load the field samples stored in a CSV file.

    Args:
        csv_path (string): absolute path of the sample CSV. The file must
            provide the columns ``id``, ``lon``, ``lat``, ``leaf`` and ``cal``.

    Returns:
        list: one ``[0, id, lon, lat, leaf, sigma]`` entry per CSV record.
        The leading ``0`` occupies the date position of the sample layout and
        ``sigma`` is the ``cal`` column converted with ``10 ** (cal / 10)``.
    """
    table = pds.read_csv(csv_path)
    # Keep only the columns the sample layout needs.
    table = table.loc[:, ['id', 'lon', 'lat', 'leaf', "cal"]]
    return [
        [
            0,
            table.loc[row, 'id'],
            table.loc[row, 'lon'],  # lon, x
            table.loc[row, 'lat'],  # lat, y
            table.loc[row, 'leaf'],
            10 ** (float(table.loc[row, 'cal']) / 10),
        ]
        for row in range(len(table))
    ]
def read_tiff(tiff_path):
    """Read a raster image from disk through the shared image handler.

    Args:
        tiff_path (string): path of the image file.

    Returns:
        dict: the three values returned by ``imageHandler.read_img`` stored
        under the keys ``'proj'``, ``'geotrans'`` and ``'data'``.
    """
    projection, geo_transform, pixels = imageHandler.read_img(tiff_path)
    raster = {}
    raster['proj'] = projection
    raster['geotrans'] = geo_transform
    raster['data'] = pixels
    return raster
def ReprojectImages2(in_tiff_path, ref_tiff_path, out_tiff_path, resampleAlg=gdalconst.GRA_Bilinear):
    """Resample the input image onto the grid of the reference image.

    The output GeoTIFF is created with the size, geotransform, projection,
    band count and band-1 data type of the reference image, and the input
    image is then warped into it with ``gdal.Warp``.

    Args:
        in_tiff_path (string): input image path.
        ref_tiff_path (string): reference image path.
        out_tiff_path (string): output image path.
        resampleAlg (gdalconst): resampling method handed to
            ``gdal.WarpOptions``; defaults to bilinear.
    """
    # Projection of the image being resampled.
    src_ds = gdal.Open(in_tiff_path, gdal.GA_ReadOnly)
    src_proj = src_ds.GetProjection()

    # Grid definition taken from the reference image.
    ref_ds = gdal.Open(ref_tiff_path, gdal.GA_ReadOnly)
    ref_proj = ref_ds.GetProjection()
    ref_trans = ref_ds.GetGeoTransform()
    ref_band = ref_ds.GetRasterBand(1)
    x = ref_ds.RasterXSize
    y = ref_ds.RasterYSize
    # NOTE(review): the band count comes from the reference image, not from
    # the input image - confirm this is intended for multi-band inputs.
    nbands = ref_ds.RasterCount

    # Create the output file and stamp it with the reference georeferencing.
    driver = gdal.GetDriverByName('GTiff')
    output = driver.Create(out_tiff_path, x, y, nbands, ref_band.DataType)
    output.SetGeoTransform(ref_trans)
    output.SetProjection(ref_proj)

    # Honour the caller-supplied resampling method (it used to be ignored in
    # favour of a hard-coded bilinear setting).
    options = gdal.WarpOptions(srcSRS=src_proj, dstSRS=ref_proj, resampleAlg=resampleAlg)
    gdal.Warp(output, in_tiff_path, options=options)

    # Push pending writes to disk and drop the dataset handles explicitly.
    output.FlushCache()
    output = None
    ref_band = None
    ref_ds = None
    src_ds = None
def combine_sample_attr(sample_list, attr_tiff, window_radius=4):
    """Append a raster attribute value to every sample that falls in the image.

    Each sample's lon/lat (items 2 and 3) is converted to pixel coordinates
    with the inverse geotransform of ``attr_tiff``. The mean of the pixels in
    a window around that position is appended to the sample. Samples whose
    pixel position is outside the image are dropped from the result.

    Args:
        sample_list (list): original samples; every kept sample list is
            extended in place with the new attribute.
        attr_tiff (dict): raster with the keys ``'geotrans'`` (geotransform)
            and ``'data'`` (array indexed as ``[row, col]``).
        window_radius (int): number of pixels added on each side of the
            sample position; the default 4 gives the 9x9 window.

    Returns:
        list: the kept samples, each with the window mean appended.
    """
    result = []
    data = attr_tiff['data']
    n_rows = data.shape[0]
    n_cols = data.shape[1]
    # The attribute raster has a coarse resolution (per the original note),
    # so a neighbourhood mean is used instead of a single pixel value.
    inv_gt = gdal.InvGeoTransform(attr_tiff['geotrans'])
    for sample_item in sample_list:
        sample_lon = sample_item[2]
        sample_lat = sample_item[3]
        sample_in_tiff_x = inv_gt[0] + inv_gt[1] * sample_lon + inv_gt[2] * sample_lat  # x
        sample_in_tiff_y = inv_gt[3] + inv_gt[4] * sample_lon + inv_gt[5] * sample_lat  # y
        x_min = int(np.floor(sample_in_tiff_x))
        x_max = int(np.ceil(sample_in_tiff_x))
        y_min = int(np.floor(sample_in_tiff_y))
        y_max = int(np.ceil(sample_in_tiff_y))
        # Skip samples that do not fall inside the raster.
        if x_min < 0 or y_min < 0 or x_max >= n_cols or y_max >= n_rows:
            continue
        # Clamp the window to the raster. The lower bound used to test
        # "index - 9 >= 0" while subtracting 4, which collapsed the window
        # to the image edge for samples 4..8 pixels away from it.
        x_lo = max(x_min - window_radius, 0)
        y_lo = max(y_min - window_radius, 0)
        x_hi = min(x_max + window_radius, n_cols)  # exclusive slice end
        y_hi = min(y_max + window_radius, n_rows)  # exclusive slice end
        interp_value = np.mean(data[y_lo:y_hi, x_lo:x_hi])
        sample_item.append(interp_value)
        result.append(sample_item)
    return result
def check_sample(sample_list):
    """Drop samples whose values are outside the accepted ranges.

    Args:
        sample_list (list): samples laid out as
            [time, code, lon, lat, LAI, csv sigma, soil moisture,
            incidence angle, sigma] with an optional trailing NDVI value.

    Returns:
        list: the samples that pass every check, in their original order.

    Raises:
        ValueError: when a sample has neither 9 nor 10 fields.
    """
    kept = []
    for record in sample_list:
        # A 10-field record carries NDVI last; it plays no part in the checks.
        core = record[:-1] if len(record) == 10 else record
        _, _, _, _, lai, _, soil, inc, sigma = core
        if (
            # backscatter must be a positive number
            sigma <= 0 or np.isnan(sigma)
            # incidence angle (radians) must not exceed 90 degrees
            or (inc * 180 / np.pi) > 90 or np.isnan(inc)
            # soil moisture must lie strictly between 0 and 1
            or soil <= 0 or soil >= 1 or np.isnan(soil)
            # LAI must lie strictly between 0 and 20
            or lai <= 0 or lai >= 20 or np.isnan(lai)
        ):
            continue
        kept.append(record)
    return kept
def split_sample_list(sample_list, train_ratio):
    """Randomly divide the samples into a training and a test subset.

    One ``random.random()`` draw is made per sample; a draw not greater than
    ``train_ratio`` sends the sample to the training subset.

    Args:
        sample_list (list): samples to divide.
        train_ratio (double): share of samples expected in the training set.

    Returns:
        list: ``[sample_train, sample_test]``.
    """
    train_part, test_part = [], []
    for sample in sample_list:
        bucket = train_part if random.random() <= train_ratio else test_part
        bucket.append(sample)
    return [train_part, test_part]
2024-10-30 03:16:40 +00:00
def WMC(A, B, M, N, sample_soil, sample_inc, sample_sigma, sample_NDVI):
    """Evaluate the LAI expression of the WMC model.

    Computes ``-(cos(inc) / (2 * B)) * ln((sigma - A*cos(inc)) /
    (M*soil + N - A*cos(inc)))``.

    Args:
        A, B, M, N (double): model parameters.
        sample_soil (double): soil moisture.
        sample_inc (double): incidence angle, passed directly to ``np.cos``.
        sample_sigma (double): backscatter coefficient.
        sample_NDVI: accepted for interface compatibility; not used.

    Returns:
        double: the LAI value given by the expression above.
    """
    cos_inc = np.cos(sample_inc)
    veg_term = A * cos_inc
    # Soil contribution is linear in the soil moisture.
    soil_sigma = M * sample_soil + N
    ratio = (sample_sigma - veg_term) / (soil_sigma - veg_term)
    return -1 * (cos_inc / (2 * B)) * np.log(ratio)
def WMCModel(param_arr, sample_lai, sample_soil, sample_inc, sample_sigma, sample_NDVI):
    """Residual of the WMC model for a single sample.

    Args:
        param_arr (np.ndarray): the four model parameters, in the order
            A, B, M, N.
        sample_lai (double): measured leaf area index.
        sample_soil (double): soil moisture.
        sample_inc (double): incidence angle in radians.
        sample_sigma (double): backscatter coefficient (linear value).
        sample_NDVI: forwarded to ``WMC`` unchanged.

    Returns:
        double: modelled LAI minus ``sample_lai``.
    """
    # Change this unpacking when the model gets a different parameter set.
    coef_a, coef_b, coef_m, coef_n = param_arr
    modelled_lai = WMC(coef_a, coef_b, coef_m, coef_n,
                       sample_soil, sample_inc, sample_sigma, sample_NDVI)
    return modelled_lai - sample_lai
def train_WMCmodel(lai_water_inc_sigma_list, params_X0, train_err_image_path, draw_flag=False):
    """Fit the WMC model parameters to the training samples by least squares.

    Args:
        lai_water_inc_sigma_list (list): training samples. Item 4 is the LAI,
            item 6 the soil moisture, item 7 the incidence angle, item 8 the
            backscatter read from the image, and item 9 (optional) the NDVI.
        params_X0: initial guess of the model parameters handed to
            ``scipy.optimize.leastsq``.
        train_err_image_path: currently unused; the error plot that wrote to
            this path is disabled.
        draw_flag (bool): when true, log extra diagnostics about the
            residuals.

    Returns:
        np.ndarray: the fitted parameters (first element of the ``leastsq``
        result).
    """
    # Extract the model inputs once instead of on every residual evaluation.
    # Samples may have 9 fields (no NDVI); the model ignores NDVI, so a
    # missing value is passed as None rather than raising IndexError.
    model_inputs = []
    for item in lai_water_inc_sigma_list:
        sample_NDVI = item[9] if len(item) > 9 else None
        # item[8] is the image backscatter (item[5] would be the CSV one).
        model_inputs.append((item[4], item[6], item[7], item[8], sample_NDVI))

    def f(X):
        # One residual (modelled LAI - measured LAI) per training sample.
        return [
            WMCModel(X, sample_lai, sample_soil, sample_inc, sample_sigma, sample_NDVI)
            for sample_lai, sample_soil, sample_inc, sample_sigma, sample_NDVI in model_inputs
        ]

    X0 = params_X0  # initial values
    logger.info(str(X0))
    h = leastsq(f, X0)
    # Pass the values as lazy %-args: using the array itself as the format
    # string made the logging call fail to format and lose the record.
    logger.info("%s %s", h[0], h[1])
    err_f = f(h[0])

    # Report the samples ordered by residual.
    logger.info("训练集:\n根据误差输出点序\n数量:{}\n点序\t误差值\t 样点信息".format(str(np.array(err_f).shape)))
    for i in np.argsort(np.array(err_f)):
        logger.info('{}\t{}\t{}'.format(i, err_f[i], str(lai_water_inc_sigma_list[i])))
    logger.info("\n误差点序输出结束\n")

    if draw_flag:
        logger.info(err_f)
        logger.info(np.where(np.abs(err_f) < 10))
        logger.info("%s %s", len(err_f), len(model_inputs))
    return h[0]