microproduct-l-sar/leafAreaIndex-L-SAR/sample_process.py

204 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#
# Libraries used for sample processing
#
from tool.algorithm.image.ImageHandle import ImageHandler
import math
import numpy as np
import random
import scipy
# Least-squares / root-finding solvers for nonlinear systems of equations
from scipy.optimize import leastsq,fsolve,root
from osgeo import gdal,gdalconst
import pandas as pds
from scipy import interpolate
from multiprocessing import pool
# Constant declaration area.
# Module-level ImageHandler instance; read_tiff() calls its read_img() method.
imageHandler=ImageHandler()
# Python function definitions
def read_sample_csv(csv_path):
    """Read the field-sample table from a GBK-encoded CSV file.

    Only the columns '样本号' (sample id), '经度' (longitude), '纬度'
    (latitude), '叶面积指数' (leaf area index) and '后向散射系数'
    (backscatter coefficient) are used; any other column is ignored.

    Args:
        csv_path (str): Path of the sample CSV file.

    Returns:
        list[list]: One record per CSV row, in file order, laid out as
            [0, sample_id, longitude, latitude, lai, sigma]
        The leading 0 is a constant placeholder in the first slot (no date
        column is read from the CSV). ``sigma`` is the backscatter column
        converted with ``10 ** (value / 10)`` (presumably dB to linear
        power; confirm against the data source).

    Raises:
        KeyError: If one of the required columns is missing from the CSV.
    """
    columns = ['样本号', '经度', '纬度', '叶面积指数', '后向散射系数']
    lai_csv = pds.read_csv(csv_path, encoding='GBK')
    lai_csv = lai_csv.loc[:, columns]
    # Walk the five columns in lockstep rather than doing one label lookup
    # per cell; this avoids the per-cell .loc overhead and does not rely on
    # the frame's index labels being 0..n-1.
    column_values = (lai_csv[name].to_numpy() for name in columns)
    return [
        [0, code, lon, lat, lai, 10 ** (float(sigma_raw) / 10)]
        for code, lon, lat, lai, sigma_raw in zip(*column_values)
    ]
def read_tiff(tiff_path):
    """Load an image file through the module-level ``imageHandler``.

    Args:
        tiff_path (str): Path of the image file to read.

    Returns:
        dict: The three values returned by ``imageHandler.read_img`` stored
        under the keys 'proj', 'geotrans' and 'data', in that order.
    """
    projection, geotransform, pixels = imageHandler.read_img(tiff_path)
    return dict(proj=projection, geotrans=geotransform, data=pixels)
def ReprojectImages2(in_tiff_path, ref_tiff_path, out_tiff_path, resampleAlg=gdalconst.GRA_Bilinear):
    """Resample an input image onto the grid of a reference image.

    A GeoTIFF is created at ``out_tiff_path`` with the reference image's
    width, height, band count, band-1 data type, geotransform and
    projection; the input image is then warped into it with ``gdal.Warp``.

    Args:
        in_tiff_path (str): Path of the image to resample.
        ref_tiff_path (str): Path of the reference image defining the
            output grid.
        out_tiff_path (str): Path of the GeoTIFF to create.
        resampleAlg: GDAL resampling algorithm passed to
            ``gdal.WarpOptions`` (defaults to bilinear).

    Raises:
        RuntimeError: If either input cannot be opened or the output file
            cannot be created.
    """
    # Source image: only its projection is needed here.
    inputrasfile = gdal.Open(in_tiff_path, gdal.GA_ReadOnly)
    if inputrasfile is None:
        raise RuntimeError('cannot open input image: {}'.format(in_tiff_path))
    inputProj = inputrasfile.GetProjection()

    # Reference image: defines the output grid.
    referencefile = gdal.Open(ref_tiff_path, gdal.GA_ReadOnly)
    if referencefile is None:
        raise RuntimeError('cannot open reference image: {}'.format(ref_tiff_path))
    referencefileProj = referencefile.GetProjection()
    referencefileTrans = referencefile.GetGeoTransform()
    bandreferencefile = referencefile.GetRasterBand(1)
    x = referencefile.RasterXSize
    y = referencefile.RasterYSize
    nbands = referencefile.RasterCount

    # Create the output file with the reference projection and geotransform.
    driver = gdal.GetDriverByName('GTiff')
    output = driver.Create(out_tiff_path, x, y, nbands, bandreferencefile.DataType)
    if output is None:
        raise RuntimeError('cannot create output image: {}'.format(out_tiff_path))
    output.SetGeoTransform(referencefileTrans)
    output.SetProjection(referencefileProj)

    # Honour the caller's resampling choice (it used to be hard-coded to
    # bilinear, silently ignoring the resampleAlg argument).
    options = gdal.WarpOptions(srcSRS=inputProj, dstSRS=referencefileProj, resampleAlg=resampleAlg)
    gdal.Warp(output, in_tiff_path, options=options)

    # Flush and drop the dataset references explicitly so the output is
    # written to disk before the function returns.
    output.FlushCache()
    output = None
    bandreferencefile = None
    referencefile = None
    inputrasfile = None
def combine_sample_attr(sample_list, attr_tiff):
    """Append a raster-derived attribute to every sample inside the raster.

    Each sample's longitude/latitude (``sample[2]``, ``sample[3]``) is
    converted to pixel coordinates with the inverse of the raster's
    geotransform. Samples whose enclosing pixel cell falls outside the
    raster are dropped. For the others, the mean of the surrounding window
    (up to 9x9 pixels, clipped at the raster edges) is appended to the
    sample.

    Args:
        sample_list (list): Samples; each one is a list whose elements 2
            and 3 are the x (longitude) and y (latitude) coordinates.
        attr_tiff (dict): Raster with keys 'geotrans' (GDAL geotransform)
            and 'data' (array indexed as [row, col]).

    Returns:
        list: The samples that fall inside the raster. These are the same
        list objects as in ``sample_list``, extended in place with the new
        attribute value.
    """
    result = []
    data = attr_tiff['data']
    n_rows, n_cols = data.shape[0], data.shape[1]
    # Original note: the attribute raster (e.g. soil) has a coarser
    # resolution, which is presumably why a window mean is used.
    inv_gt = gdal.InvGeoTransform(attr_tiff['geotrans'])
    half_window = 4  # 4 pixels on each side of the sample cell -> 9x9 window

    for sample_item in sample_list:
        sample_lon = sample_item[2]
        sample_lat = sample_item[3]
        # Geographic -> pixel coordinates (fractional column / row).
        sample_in_tiff_x = inv_gt[0] + inv_gt[1] * sample_lon + inv_gt[2] * sample_lat
        sample_in_tiff_y = inv_gt[3] + inv_gt[4] * sample_lon + inv_gt[5] * sample_lat
        x_min = int(np.floor(sample_in_tiff_x))
        x_max = int(np.ceil(sample_in_tiff_x))
        y_min = int(np.floor(sample_in_tiff_y))
        y_max = int(np.ceil(sample_in_tiff_y))
        # Skip samples whose surrounding pixel cell is not fully inside the raster.
        if x_min < 0 or y_min < 0 or x_max >= n_cols or y_max >= n_rows:
            continue

        # Grow the cell into the averaging window and clip it to the raster.
        # The lower bound used to be tested against an offset of 9 while
        # subtracting 4, which snapped the window to index 0 for samples
        # 4..8 pixels away from the left/top edge.
        x_lo = max(x_min - half_window, 0)
        y_lo = max(y_min - half_window, 0)
        x_hi = min(x_max + half_window, n_cols)
        y_hi = min(y_max + half_window, n_rows)

        interp_value = np.mean(data[y_lo:y_hi, x_lo:x_hi])
        sample_item.append(interp_value)
        result.append(sample_item)
    return result
def check_sample(sample_list):
    """Filter out samples whose values are missing or out of range.

    Each sample is a sequence of 9 or 10 fields in this order:
        [time, code, lon, lat, lai, csv_sigma, soil, inc, sigma(, ndvi)]

    A sample is kept only when all of the following hold:
        * ``sigma`` is not NaN and is > 0
        * ``inc`` is not NaN and ``inc * 180 / pi`` is <= 90
        * ``soil`` is not NaN and lies strictly between 0 and 1
        * ``lai`` is not NaN and lies strictly between 0 and 20
        * ``csv_sigma`` is not NaN

    The optional tenth field (NDVI) is not validated.

    Args:
        sample_list (list): Samples laid out as described above.

    Returns:
        list: The samples that passed every check, in their original order
        (the same objects as in ``sample_list``).

    Raises:
        ValueError: If a sample has neither 9 nor 10 fields.
    """
    result = []
    for item in sample_list:
        if len(item) not in (9, 10):
            raise ValueError(
                'sample must have 9 or 10 fields, got {}'.format(len(item)))
        # Only the first nine fields are inspected. Slicing handles both
        # layouts; the previous `np.nan not in item` branch sent 10-field
        # samples containing the np.nan object into the 9-field unpack,
        # which raised ValueError.
        (sample_time, sample_code, sample_lon, sample_lat, sample_lai,
         csv_sigma, sample_soil, sample_inc, sample_sigma) = item[:9]

        if sample_sigma <= 0 or np.isnan(sample_sigma):
            continue
        if (sample_inc * 180 / np.pi) > 90 or np.isnan(sample_inc):
            continue
        if sample_soil <= 0 or sample_soil >= 1 or np.isnan(sample_soil):
            continue
        if sample_lai <= 0 or sample_lai >= 20 or np.isnan(sample_lai):
            continue
        if np.isnan(csv_sigma):
            continue
        result.append(item)
    return result
def split_sample_list(sample_list, train_ratio):
    """Randomly split samples into a training set and a test set.

    Every sample is assigned independently: it goes to the training set
    when a fresh ``random.random()`` draw is <= ``train_ratio``, otherwise
    to the test set. The resulting proportions are therefore approximate,
    and the split is reproducible only if the caller seeds ``random``.

    Args:
        sample_list (list): Samples to split.
        train_ratio (float): Probability of a sample being placed in the
            training set.

    Returns:
        list: ``[sample_train, sample_test]``; both keep the input order.
    """
    sample_train = []
    sample_test = []
    # Iterate the samples directly; exactly one random draw per sample, in
    # input order, so seeded runs produce the same split as before.
    for sample in sample_list:
        if random.random() <= train_ratio:
            sample_train.append(sample)
        else:
            sample_test.append(sample)
    return [sample_train, sample_test]