# Listing metadata from the file viewer: 155 lines, 6.2 KiB, Python
import glob
|
||
import multiprocessing
|
||
import os
|
||
|
||
import cv2
|
||
from VegetationPhenologyAuxData import PhenoloyMeasCsv_geo
|
||
from tool.algorithm.image.ImageHandle import ImageHandler
|
||
from tool.algorithm.ml.machineLearning import MachineLeaning as ml, MachineLeaning
|
||
from tool.csv.csvHandle import csvHandle
|
||
from tool.algorithm.block.blockprocess import BlockProcess as bp, BlockProcess
|
||
import numpy as np
|
||
from tool.algorithm.algtools.PreProcess import PreProcess as pp
|
||
from tool.file.fileHandle import fileHandle
|
||
# Shared helper instances used by the functions below:
#   file - used by predict_VP to create the per-block output directory
#   csvh - used by featuresRoi to turn the sample list into a dict
file = fileHandle()
csvh = csvHandle()
# Passed as the third argument to PhenoloyMeasCsv_geo in featuresRoi.
# NOTE(review): presumably the maximum number of training samples - confirm.
MAX_TRAN_NUM = 100000
|
||
|
||
|
||
def predict_VP(clf, X_test_list, out_tif_name, workspace_processing_path, rows, cols):
    """
    Predict every feature block in a process pool and merge the block results.

    :param clf: trained model handed unchanged to ``MachineLeaning.predict_blok``
    :param X_test_list: paths of the block feature images. Each file name must
        end with four ``_``-separated fields; the first and third of them are
        parsed as the block's starting row and starting column.
    :param out_tif_name: base name of the product (``'_VTH'`` is appended)
    :param workspace_processing_path: working directory; a trailing path
        separator is optional
    :param rows: height of the full image (used to derive the block size and
        passed to the merge step)
    :param cols: width of the full image (used likewise)
    :return: ``<workspace>/<out_tif_name>_VTH.tif``, the path where the merged
        image is expected after ``BlockProcess.combine`` has run
    """
    predictor = MachineLeaning()
    block_tool = BlockProcess()
    block_size = block_tool.get_block_size(rows, cols)
    name_d = out_tif_name + '_VTH'

    block_paths = list(X_test_list)
    block_count = len(block_paths)

    # Per-block predictions go to <workspace>/<name>_VTH/pre_result/. The empty
    # last component keeps the trailing separator the original literal
    # 'pre_result\\' provided, but uses the platform's own separator.
    bp_cover_dir = os.path.join(workspace_processing_path, name_d, 'pre_result', '')
    file.creat_dirs([bp_cover_dir])

    # Keep a few cores free, but never ask for fewer than one worker:
    # cpu_count() - 7 is <= 0 on small machines (and the block list may be
    # empty), and Pool(processes=0) raises ValueError.
    processes_num = max(1, min(block_count, multiprocessing.cpu_count() - 7))
    pool = multiprocessing.Pool(processes=processes_num)
    try:
        for n, path in enumerate(block_paths):
            name = os.path.split(path)[1]
            name_fields = name.split('_')

            if ImageHandler.get_bands(path) == 1:
                # Single-band block: lift it to (1, block_size, block_size)
                # so the reshape below works the same as for multi-band data.
                features_array = np.zeros((1, block_size, block_size), dtype=float)
                features_array[0, :, :] = ImageHandler.get_data(path)
            else:
                features_array = ImageHandler.get_data(path)

            # One row per pixel, one column per band.
            X_test = np.reshape(
                features_array,
                (features_array.shape[0], features_array[0].size)).T

            # Re-use the last four name fields for the output block file.
            suffix = '_' + '_'.join(
                [name_fields[-4], name_fields[-3], name_fields[-2], name_fields[-1]])
            img_path = os.path.join(bp_cover_dir, name_d + suffix)
            row_begin = int(name_fields[-4])
            col_begin = int(name_fields[-2])

            pool.apply_async(predictor.predict_blok, (
                clf, X_test, block_size, block_size, img_path,
                row_begin, col_begin, block_count, n))

        pool.close()
        pool.join()
    finally:
        # No-op after a clean close()/join(); stops the workers if preparing
        # a block raised before the pool was closed.
        pool.terminate()

    # Merge the block results. Strip trailing separators instead of blindly
    # dropping the last character, so a path given without a trailing
    # separator is not truncated (e.g. '...\block' -> '...\bloc').
    out_path = workspace_processing_path.rstrip('\\/') or workspace_processing_path
    block_tool.combine(bp_cover_dir, cols, rows, out_path, file_type=['tif'], datetype='float32')

    # NOTE: no spatial reference is assigned here; callers write the
    # georeferencing themselves.
    cover_path = os.path.join(out_path, name_d + ".tif")
    return cover_path
|
||
|
||
def _sample_training_pixels(featurePath, dim, ids, positions):
    """Return (X_train, Y_train) sampled from the feature image at the labelled positions."""
    sample_rows, sample_cols = [], []
    # Start from an empty float column so the label dtype is promoted exactly
    # as when stacking onto np.empty((0, 1)).
    label_columns = [np.empty(shape=(0, 1))]
    for label, points in zip(ids, positions):
        if len(points) == 0:
            raise Exception('data is empty!')
        point_rows, point_cols = zip(*points)
        sample_rows.extend(point_rows)
        sample_cols.extend(point_cols)
        label_columns.append(np.full((len(points), 1), label))

    X_train = np.empty(shape=(len(sample_rows), dim))
    if sample_rows:
        # Read every band once and sample all classes from it, instead of
        # re-reading each band from disk for every class.
        for n in range(dim):
            band = ImageHandler.get_band_array(featurePath, n + 1)
            values = band[sample_rows, sample_cols]
            # Invalid samples (NaN or the -9999 nodata value) are set to 0.
            values[np.isnan(values)] = 0
            values[values == -9999] = 0
            X_train[:, n] = values

    Y_train = np.vstack(label_columns).T[0, :]
    return X_train, Y_train


def _remap_label_ids(label_image, ids, class_ids):
    """Replace each id in label_image by its class id, in place, and return the image."""
    # Compare against a snapshot: remapping in place could otherwise turn a
    # pixel into a value that equals a later id and remap it a second time.
    predicted = label_image.copy()
    for label, class_id in zip(ids, class_ids):
        label_image[predicted == label] = class_id
    return label_image


def featuresRoi(featurePath, roi,
                tif_dir=r"D:\BaiduNetdiskDownload\envi-result\cut_tif",
                blockDir=r'D:\BaiduNetdiskDownload\envi-result\block'):
    """
    Train a classifier on ROI samples, classify the image blocks and write the result.

    Steps: read the labelled samples, sample the feature image at their
    positions, train via ``MachineLeaning.trainRF``, cut ``tif_dir`` into
    blocks, predict them with ``predict_VP``, apply a 10x10 morphological
    closing, map the predicted ids to class ids and write
    ``<product name>-VPtemp.tif`` into ``blockDir``.

    :param featurePath: path of the multi-band feature image
    :param roi: path of the sample file handed to ``PhenoloyMeasCsv_geo``
    :param tif_dir: directory of the images to cut into blocks; defaults to
        the previously hard-coded location
    :param blockDir: directory receiving the blocks and the products; defaults
        to the previously hard-coded location
    :raises Exception: if a class has no sample positions
    """
    pm = PhenoloyMeasCsv_geo(roi, featurePath, MAX_TRAN_NUM)
    train_data_list = pm.api_read_measure_by_name('0')
    train_data_dic = csvh.trans_landCover_list2dic(train_data_list)

    rows = ImageHandler.get_img_height(featurePath)
    cols = ImageHandler.get_img_width(featurePath)
    # Only the georeferencing is needed here; it is re-used for the output.
    im_proj, im_geo, _ = ImageHandler.read_img(featurePath)
    dim = ImageHandler.get_bands(featurePath)

    X_train, Y_train = _sample_training_pixels(
        featurePath, dim, train_data_dic['ids'], train_data_dic['positions'])
    clf = ml.trainRF(X_train, Y_train)

    bp().cut(tif_dir, blockDir, out_size=2048)
    X_test_list = list(glob.glob(os.path.join(blockDir, '*.tif')))
    product_path = predict_VP(clf, X_test_list, 'geo', blockDir, rows, cols)

    _, _, cover_data = ImageHandler.read_img(product_path)

    # Morphological closing (dilate, then erode) to remove speckle noise.
    # The uint8 cast assumes the predicted ids fit in 0..255.
    cover_data = np.uint8(cover_data)
    kernel = np.ones((10, 10), np.uint8)
    cover_data = cv2.erode(cv2.dilate(cover_data, kernel), kernel)
    cover_data = np.int16(cover_data)

    _remap_label_ids(cover_data, train_data_dic['ids'], train_data_dic['class_ids'])

    # NOTE: nodata pixels of the feature image are not masked in the output.
    cover_geo_path = os.path.join(
        blockDir, os.path.basename(product_path).split('.tif')[0] + '-VPtemp.tif')
    ImageHandler.write_img(cover_geo_path, im_proj, im_geo, cover_data, '-9999')
|
||
|
||
|
||
if __name__ == '__main__':
    # Ad-hoc local run: train on the ROI samples and classify the feature image.
    # Both paths are machine-specific and must be adjusted before running.
    features = r"D:\micro\SWork\LandCover\Temporary\processing\features_geo\Freeman_Vol_geo.tif"
    roi = r"F:\xibei_LandCover\landCoverSamples.csv"
    featuresRoi(features, roi)
    # Disabled experiment, kept for reference: reads the feature image and a
    # land-cover raster, sets every band to -9999 where the land-cover value is
    # 20, 30, 40, 50, 70, 90 or 100, and writes the masked features out.
    # oytTil = r"D:\BaiduNetdiskDownload\54Features_land.tif"
    # outP = r"D:\BaiduNetdiskDownload\veg_landcover_re.tif"
    # #
    # im_proj, im_geotrans, im_arr = ImageHandler.read_img(features)
    # _, _, land = ImageHandler.read_img(outP)
    # # pp.resampling_by_scale(land, outP, features)
    # for i in range(im_arr.shape[0]):
    # im_arr[i, :, :][np.where(land==20)] = -9999
    # im_arr[i, :, :][np.where(land==30)] = -9999
    # im_arr[i, :, :][np.where(land==40)] = -9999
    # im_arr[i, :, :][np.where(land==50)] = -9999
    # im_arr[i, :, :][np.where(land==70)] = -9999
    # im_arr[i, :, :][np.where(land==90)] = -9999
    # im_arr[i, :, :][np.where(land==100)] = -9999
    # ImageHandler.write_img(oytTil, im_proj, im_geotrans, im_arr, '-9999')
    # print(11)
|
||
|