202 lines
7.2 KiB
Python
202 lines
7.2 KiB
Python
#
|
||
# 样本处理的相关的库
|
||
#
|
||
|
||
from tool.algorithm.image.ImageHandle import ImageHandler
|
||
import math
|
||
import numpy as np
|
||
import random
|
||
import scipy
|
||
# 最小二乘求解非线性方程组
|
||
from scipy.optimize import leastsq,fsolve,root
|
||
from osgeo import gdal,gdalconst
|
||
import pandas as pds
|
||
from scipy import interpolate
|
||
from multiprocessing import pool
|
||
# 常量声明区域
|
||
# Module-level ImageHandler instance; read_tiff reads rasters through its read_img method.
imageHandler=ImageHandler()
|
||
|
||
|
||
# python 的函数类
|
||
def read_sample_csv(csv_path):
    """Read field samples from a CSV file.

    The CSV must contain the columns ``id``, ``lon``, ``lat``, ``leaf`` and
    ``cal``; any other columns are ignored.

    Args:
        csv_path (string): absolute path of the sample CSV (anything that
            ``pandas.read_csv`` accepts works).

    Returns:
        list: one entry per CSV row, laid out as
        ``[date, plot_id, lon, lat, leaf_area_index, backscatter]`` where
        ``date`` is always the placeholder ``0`` and ``backscatter`` is the
        ``cal`` value converted from dB to linear scale (``10 ** (cal / 10)``).
    """
    table = pds.read_csv(csv_path)
    wanted = ('id', 'lon', 'lat', 'leaf', 'cal')
    columns = [table[name].to_numpy() for name in wanted]

    samples = []
    for plot_id, lon, lat, leaf, cal in zip(*columns):
        backscatter = 10 ** (float(cal) / 10)  # dB -> linear
        samples.append([0, plot_id, lon, lat, leaf, backscatter])
    return samples
|
||
def read_tiff(tiff_path):
    """Load a raster file through the module-level ``imageHandler``.

    Args:
        tiff_path (string): path of the raster file to read.

    Returns:
        dict: ``{'proj': ..., 'geotrans': ..., 'data': ...}`` holding, in that
        order, the three values returned by ``imageHandler.read_img``
        (presumably projection, geotransform and pixel array -- confirm
        against ImageHandler).
    """
    projection, geo_transform, pixels = imageHandler.read_img(tiff_path)
    return dict(proj=projection, geotrans=geo_transform, data=pixels)
|
||
|
||
def ReprojectImages2(in_tiff_path, ref_tiff_path, out_tiff_path, resampleAlg=gdalconst.GRA_Bilinear):
    """Resample an input raster onto the grid of a reference raster.

    A new GeoTIFF is created at ``out_tiff_path`` with the reference raster's
    size, band count, band-1 data type, geotransform and projection. The input
    raster is then warped into it with ``gdal.Warp``.

    Args:
        in_tiff_path (string): path of the raster to resample.
        ref_tiff_path (string): path of the reference raster whose grid the
            output adopts.
        out_tiff_path (string): path of the GeoTIFF to create.
        resampleAlg (gdalconst.GRA_*): resampling algorithm handed to
            ``gdal.WarpOptions``; defaults to bilinear.

    Raises:
        RuntimeError: if the input or reference raster cannot be opened, or
            the output raster cannot be created.
    """
    # Input raster: only its projection is read here; gdal.Warp reads the pixels.
    input_ds = gdal.Open(in_tiff_path, gdal.GA_ReadOnly)
    if input_ds is None:
        raise RuntimeError("cannot open input raster: %s" % in_tiff_path)
    input_proj = input_ds.GetProjection()

    # Reference raster: defines the output grid.
    reference_ds = gdal.Open(ref_tiff_path, gdal.GA_ReadOnly)
    if reference_ds is None:
        raise RuntimeError("cannot open reference raster: %s" % ref_tiff_path)
    reference_proj = reference_ds.GetProjection()
    reference_trans = reference_ds.GetGeoTransform()
    # NOTE(review): the output uses the *reference* raster's band-1 data type,
    # so input values are cast to it -- confirm that is intended.
    data_type = reference_ds.GetRasterBand(1).DataType
    x_size = reference_ds.RasterXSize
    y_size = reference_ds.RasterYSize
    n_bands = reference_ds.RasterCount

    # Create the output with the reference grid (size, geotransform, projection).
    driver = gdal.GetDriverByName('GTiff')
    output = driver.Create(out_tiff_path, x_size, y_size, n_bands, data_type)
    if output is None:
        raise RuntimeError("cannot create output raster: %s" % out_tiff_path)
    output.SetGeoTransform(reference_trans)
    output.SetProjection(reference_proj)

    # Pass the caller's resampling algorithm through instead of hard-coding bilinear.
    options = gdal.WarpOptions(srcSRS=input_proj, dstSRS=reference_proj, resampleAlg=resampleAlg)
    gdal.Warp(output, input_ds, options=options)

    # Flush and release the datasets (GDAL closes a dataset when its last
    # reference is dropped) so the GeoTIFF is fully written on return.
    output.FlushCache()
    output = None
    reference_ds = None
    input_ds = None
|
||
|
||
def combine_sample_attr(sample_list, attr_tiff, window_radius=4):
    """Append a raster attribute, averaged around each sample, to the samples.

    Each sample's (lon, lat) is taken from indices 2 and 3. It is mapped to
    fractional pixel coordinates of ``attr_tiff`` through the inverse
    geotransform; no reprojection is done, so the coordinates must already be
    in the raster's CRS. A sample is dropped unless the 2x2 pixel
    neighbourhood around it (floor/ceil of the pixel coordinates) lies inside
    the raster. For every kept sample, the code appends the mean of a
    ``(2 * window_radius + 1)``-pixel square window. The window is centred on
    the pixel that contains the sample and clipped at the raster border.

    NOTE: kept samples are modified in place (the value is appended to the
    caller's ``sample_item`` list) and those same list objects are returned.

    Args:
        sample_list (list): samples, each a mutable list with longitude at
            index 2 and latitude at index 3.
        attr_tiff (dict): raster with ``'geotrans'`` (a GDAL geotransform)
            and ``'data'`` (a 2-D array indexed ``[row, col]``), presumably as
            produced by ``read_tiff``.
        window_radius (int): half-size of the averaging window in pixels;
            the default 4 gives a 9x9 window.

    Returns:
        list: the kept samples, each with the window mean appended.

    Raises:
        ValueError: if the geotransform cannot be inverted.
    """
    data = attr_tiff['data']
    n_rows, n_cols = data.shape[0], data.shape[1]
    inv_gt = gdal.InvGeoTransform(attr_tiff['geotrans'])
    if inv_gt is None:
        raise ValueError("attr_tiff['geotrans'] is not invertible")

    result = []
    for sample_item in sample_list:
        lon = sample_item[2]
        lat = sample_item[3]
        # Fractional pixel coordinates of the sample (col = x, row = y).
        col = inv_gt[0] + inv_gt[1] * lon + inv_gt[2] * lat
        row = inv_gt[3] + inv_gt[4] * lon + inv_gt[5] * lat
        col_lo, col_hi = int(np.floor(col)), int(np.ceil(col))
        row_lo, row_hi = int(np.floor(row)), int(np.ceil(row))
        if col_lo < 0 or row_lo < 0 or col_hi >= n_cols or row_hi >= n_rows:
            continue

        # The attribute raster (e.g. soil_tiff) can have a lower resolution,
        # so average a square window around the containing pixel instead of
        # reading a single value. Slice ends are exclusive; clip at the border.
        c0 = max(col_lo - window_radius, 0)
        c1 = min(col_lo + window_radius + 1, n_cols)
        r0 = max(row_lo - window_radius, 0)
        r1 = min(row_lo + window_radius + 1, n_rows)
        sample_item.append(np.mean(data[r0:r1, c0:c1]))
        result.append(sample_item)
    return result
|
||
|
||
def check_sample(sample_list):
    """Drop samples whose values are outside their plausible ranges.

    Args:
        sample_list (list): samples, each with 9 or 10 fields:
            ``[date, plot_id, lon, lat, LAI, csv_sigma, soil_moisture,
            incidence_angle, sigma]`` plus an optional trailing NDVI.
            ``incidence_angle`` is treated as radians.

    Returns:
        list: the samples (same objects, original order) for which
        ``sigma > 0``, ``incidence_angle <= 90 deg``,
        ``0 < soil_moisture < 1`` and ``0 < LAI < 20``. A NaN in any of
        these fields rejects the sample.

    Raises:
        ValueError: if a sample has neither 9 nor 10 fields.
    """
    result = []
    for item in sample_list:
        if len(item) not in (9, 10):
            raise ValueError("expected a sample with 9 or 10 fields, got %d" % len(item))
        sample_lai = item[4]
        sample_soil = item[6]
        sample_inc = item[7]
        sample_sigma = item[8]

        # Every test states the *valid* range and negates it, so a NaN (for
        # which all comparisons are False) is rejected as well.
        if not (sample_sigma > 0):
            continue
        if not (sample_inc * 180 / np.pi <= 90):
            continue
        if not (0 < sample_soil < 1):
            continue
        if not (0 < sample_lai < 20):
            continue
        result.append(item)
    return result
|
||
|
||
|
||
|
||
def split_sample_list(sample_list, train_ratio, rng=None):
    """Randomly split samples into a training set and a test set.

    Each sample is assigned independently: it goes to the training set when
    ``rng.random() <= train_ratio``, otherwise to the test set. Split sizes
    are therefore only approximately proportional to ``train_ratio``.

    Args:
        sample_list (list): samples to split.
        train_ratio (float): probability of a sample going to the training set.
        rng (random.Random, optional): generator to draw from, e.g.
            ``random.Random(seed)`` for a reproducible split. When omitted,
            the global ``random`` module state is used.

    Returns:
        list: ``[sample_train, sample_test]``, each preserving input order.
    """
    draw = random.random if rng is None else rng.random
    sample_train = []
    sample_test = []
    for sample in sample_list:
        if draw() <= train_ratio:
            sample_train.append(sample)
        else:
            sample_test.append(sample)
    return [sample_train, sample_test]