# microproduct/vegetationPhenology/VegetationPhenologyAuxData.py
#
# 409 lines
# 13 KiB
# Python

# -*- coding: UTF-8 -*-
"""
@Project:SalinityMain.py
@File:AuxData.py
@Function: 读取指标物候辅助数据
@Contact:
@Author:SHJ
@Date:2021/10/21 19:25
@Version:1.0.0
"""
import csv
import numpy as np
import mahotas
import logging
import random
from tool.algorithm.algtools.CoordinateTransformation import lonlat2geo, geo2imagexy
from tool.algorithm.image.ImageHandle import ImageHandler
# Module-level logger used by the CSV readers below; all records go to the
# logger registered under the name "mylog".
logger = logging.getLogger(name="mylog")
class PhenologyMeasCsv:
    """Read measured phenology samples from a CSV file.

    Data rows are read as
    ``[use_type, sar_img_name, phenology_id, phenology_name, polygon]``
    (see ``__trans_measuredata``): ``use_type`` is ``'train'`` or ``'test'``
    and ``polygon`` is a ``POLYGON((lon lat, ...))`` string.
    """

    def __init__(self, csv_path, preprocessed_paras):
        """
        :param csv_path: path of the measurement CSV file.
        :param preprocessed_paras: mapping whose ``<image name>_HH`` entries hold
            the path of the preprocessed image of each SAR product.
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras

    def api_read_phenology_measure(self):
        """Read the CSV and rasterise its polygons into train/test samples.

        :return: ``(train_data, test_data, type_map)``, see ``__trans_measuredata``.
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """Return the phenology classes listed in the CSV.

        Column 2 (class id) and column 3 (class name) of every data row are
        read; the header row and blank lines are skipped. Rows with a blank
        id are ignored, and the first name seen for an id wins.

        :return: dict ``{type_id (str): type_name (str)}``.
        """
        type_id_name = {}
        for data in self.__readcsv(self.__csv_path):
            type_id = data[2]
            type_name = data[3]
            if type_id not in type_id_name and type_id.strip() != "":
                type_id_name[type_id] = type_name
        return type_id_name

    @staticmethod
    def __readcsv(csv_path):
        """Read all data rows of a CSV file (the first/header row is dropped).

        Completely blank lines (yielded by ``csv.reader`` as ``[]``) are
        skipped so they cannot raise ``IndexError`` downstream.

        :param csv_path: CSV file path.
        :return: list of rows, each a list of str.
        """
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        return [row for row in rows[1:] if row]

    @staticmethod
    def readcsv_dic(csv_path):
        """Read a CSV file into a column dictionary.

        The first row is the header; each header name maps to the values of
        its column, in row order. Blank lines are skipped and cells missing
        from short rows are filled with ``''``.

        :param csv_path: CSV file path.
        :return: dict ``{header: [value, ...]}``; ``{}`` for an empty file.
        """
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        if not rows:
            return {}
        keys = rows[0]
        datas = [row for row in rows[1:] if row]
        # Column n of every data row (the old code wrongly took row n).
        return {key: [row[n] if n < len(row) else '' for row in datas]
                for n, key in enumerate(keys)}

    @staticmethod
    def _strip_archive_suffix(sar_img_name, suffix='.tar.gz'):
        """Return *sar_img_name* without a trailing *suffix*, if present.

        ``str.rstrip('.tar.gz')`` (used previously) strips any trailing run of
        the characters ``. t a r g z`` and so also removed legitimate trailing
        characters of the name (``'x_beta.tar.gz'`` -> ``'x_be'``).
        NOTE(review): the ``<name>_HH`` keys of ``preprocessed_paras`` must be
        built the same way by the caller - confirm.
        """
        if suffix and sar_img_name.endswith(suffix):
            return sar_img_name[:-len(suffix)]
        return sar_img_name

    def __trans_measuredata(self, meas_data):
        """Rasterise every polygon into pixels and split rows into train/test sets.

        :param meas_data: data rows returned by ``__readcsv``.
        :return: ``(train_data, test_data, type_map)`` where ``train_data`` and
            ``test_data`` are lists of ``[image_name, phenology_id, [(row, col), ...]]``
            (test rows get ``phenology_id == -1``; rows of any other use type are
            dropped), and ``type_map`` lists ``[n, phenology_id, phenology_name]``
            for every id seen in a train row, in first-seen order, ``n`` starting at 1.
        """
        train_data = []
        test_data = []
        type_data = {}  # phenology_id -> first phenology_name seen (insertion-ordered)
        for data in meas_data:
            data_use_type = data[0]
            name = self._strip_archive_suffix(data[1])
            dataset, rows, cols = self.__get_para_tif_inf(name)
            if data_use_type == 'train':
                phenology_id = int(data[2])
                type_data.setdefault(phenology_id, data[3])
            else:
                phenology_id = -1
            point_list = []
            for points in self.__roiPolygonAnalysis(data[4]):
                poly = []
                for lon, lat in points:
                    projs = lonlat2geo(dataset, float(lon), float(lat))
                    coord = geo2imagexy(dataset, projs[1], projs[0])
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("measure data: %s is beyond tif scope !", data)
                # A ring with no vertex inside the image has nothing to rasterise
                # (and __render cannot handle an empty polygon).
                if poly:
                    point_list += self.__render(poly)
            if data_use_type == 'train':
                train_data.append([name, phenology_id, point_list])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, point_list])
        type_map = [[n, type_id, type_name]
                    for n, (type_id, type_name) in enumerate(type_data.items(), start=1)]
        return train_data, test_data, type_map

    @staticmethod
    def __render(poly):
        """Return every grid cell covered by a polygon.

        Based on https://www.cnpython.com/qa/51516 .

        :param poly: non-empty list of ``[row, col]`` integer vertices.
        :return: list of ``(row, col)`` cells filled by
            ``mahotas.polygon.fill_polygon`` on a grid spanning the polygon's
            bounding box.
        :raises ValueError: if *poly* is empty.
        """
        if not poly:
            raise ValueError("cannot render an empty polygon")
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)
        # Shift the vertices so the bounding box starts at (0, 0).
        local_poly = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((int(maxx - minx) + 1, int(maxy - miny) + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(local_poly, grid)
        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self, tif_name):
        """Open the preprocessed image of a SAR product.

        :param tif_name: image name; the path is looked up as
            ``preprocessed_paras[tif_name + '_HH']``.
        :return: ``(dataset, rows, cols)`` from ``ImageHandler``.
        """
        tif_path = self.__preprocessed_paras[tif_name + '_HH']
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Parse a ``POLYGON((x y, x y, ...))`` string into vertex lists.

        Each bracketed ring becomes a list of ``[x, y]`` float pairs.
        Coordinates are split on any whitespace, so both ``"1 2,3 4"`` and
        WKT-style ``"1 2, 3 4"`` are accepted; tokens without exactly two
        values are skipped.

        :param roiStr: polygon text from the CSV.
        :return: list of rings, each a list of ``[x, y]`` pairs.
        """
        content = roiStr.replace("POLYGON", "")
        # Each ')' closes a ring. Text outside brackets (e.g. the ',' between
        # rings) is carried into the next ring and discarded there as an
        # invalid coordinate token.
        rings = []
        current = ''
        for char in content:
            if char == '(':
                continue
            if char == ')':
                if current:
                    rings.append(current)
                current = ''
            else:
                current += char
        point_list = []
        for ring in rings:
            vertices = []
            for token in ring.split(','):
                cells = token.split()
                if len(cells) != 2:
                    continue
                vertices.append([float(cells[0]), float(cells[1])])
            point_list.append(vertices)
        return point_list
class PhenoloyMeasCsv_geo:
    """Read class polygons from a CSV and turn them into pixel training samples.

    Data rows are read as ``[type_parent, type_id, type_name, polygon]``
    (see ``class_list`` and ``__trans_measuredata``), where ``polygon`` is a
    ``POLYGON((lon lat, ...))`` string. Vertices are mapped onto the single
    image whose path is given as ``preprocessed_paras``.
    """

    def __init__(self, csv_path, preprocessed_paras, max_tran__num_per_class=100000):
        """
        :param csv_path: path of the sample CSV file.
        :param preprocessed_paras: path of the image the polygons are rasterised on.
        :param max_tran__num_per_class: stored only; ``__trans_measuredata`` caps each
            class at the size of the smallest class instead of this value.
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras
        self.__max_tran__num_per_class = max_tran__num_per_class

    def api_read_measure(self):
        """Build the training set from every CSV data row (header skipped).

        :return: see ``__trans_measuredata``.
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def api_read_measure_by_name(self, name):
        """Build the training set from the CSV rows whose column 0 contains *name*.

        :return: see ``__trans_measuredata``.
        """
        csv_data = self.__readcsv_by_name(self.__csv_path, name)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """Return the classes listed in the CSV.

        For every data row (header and blank lines skipped) column 1 is parsed
        as the integer class id, column 2 is the class name and column 0 the
        parent type; the first row seen for an id wins.

        :return: ``(type_id_name, type_id_parent)`` dicts keyed by class id.
        """
        type_id_name = {}
        type_id_parent = {}
        for data in self.__readcsv(self.__csv_path):
            type_id = int(data[1])
            if type_id not in type_id_name:
                type_id_name[type_id] = data[2]
                type_id_parent[type_id] = data[0]
        return type_id_name, type_id_parent

    @staticmethod
    def __readcsv(csv_path):
        """Read all data rows of a CSV file (the first/header row is dropped).

        Completely blank lines (yielded by ``csv.reader`` as ``[]``) are skipped.

        :param csv_path: CSV file path.
        :return: list of rows, each a list of str.
        """
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        return [row for row in rows[1:] if row]

    @staticmethod
    def __readcsv_by_name(csv_path, name):
        """Read the CSV rows whose first column contains *name* as a substring.

        Every row is tested, the header row included; blank lines are skipped
        (they previously raised ``IndexError``).

        :param csv_path: CSV file path.
        :param name: substring searched for in column 0.
        :return: list of matching rows, each a list of str.
        """
        with open(csv_path, newline='') as csv_file:
            return [row for row in csv.reader(csv_file) if row and name in row[0]]

    def __trans_measuredata(self, meas_data):
        """Rasterise every class polygon and build a class-balanced training set.

        :param meas_data: CSV rows ``[type_parent, type_id, type_name, polygon]``.
        :return: list of ``[n, type_id, type_name, [(row, col), ...]]``, one entry
            per class in first-seen order (``n`` is a 1-based index). Classes
            larger than the smallest one are randomly down-sampled to its size.
        :raises Exception: if a row contains an empty cell, a class has no pixel
            inside the image, or fewer than two classes are present.
        """
        # Every row refers to the same image, so open it once instead of per row.
        dataset, rows, cols = self.__get_para_tif_inf()
        entries_by_id = {}  # type_id -> its entry in train_data_list
        train_data_list = []
        for data in meas_data:
            if '' in data:
                raise Exception('there are empty data!', data)
            type_id = int(data[1])
            entry = entries_by_id.get(type_id)
            if entry is None:
                entry = [len(train_data_list) + 1, type_id, data[2], []]
                train_data_list.append(entry)
                entries_by_id[type_id] = entry
            for points in self.__roiPolygonAnalysis(data[3]):
                poly = []
                for point in points:
                    lon = float(point[0])
                    lat = float(point[1])
                    # CSV coordinates go straight to geo2imagexy (no lonlat2geo step).
                    coord = geo2imagexy(dataset, lon, lat)
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("point %s is beyond tif scope, in measure data: %s !", point, data)
                # Rings with no vertex inside the image contribute nothing
                # (__render cannot handle an empty polygon).
                if poly:
                    entry[3].extend(self.__render(poly))
        for entry in train_data_list:
            if not entry[3]:
                raise Exception('there are empty data!', entry)
        # Checked before taking the minimum so an empty CSV reports this error
        # instead of failing on min() of an empty sequence.
        if len(train_data_list) <= 1:
            raise Exception('there is only one label type!', train_data_list)
        # Balance classes to the smallest one (self.__max_tran__num_per_class is unused).
        max_num = min(len(entry[3]) for entry in train_data_list)
        logger.info("max number =" + str(max_num) +", random select"+str(max_num)+" point as train data!")
        for entry in train_data_list:
            logger.info(str(entry[0]) + "," + str(entry[2]) +"," + "num:" + str(len(entry[3])))
            if len(entry[3]) > max_num:
                entry[3] = random.sample(entry[3], max_num)
        return train_data_list

    @staticmethod
    def __render(poly):
        """Return every grid cell covered by a polygon.

        Based on https://www.cnpython.com/qa/51516 .

        :param poly: non-empty list of ``[row, col]`` integer vertices.
        :return: list of ``(row, col)`` cells filled by
            ``mahotas.polygon.fill_polygon`` on a grid spanning the polygon's
            bounding box.
        :raises ValueError: if *poly* is empty.
        """
        if not poly:
            raise ValueError("cannot render an empty polygon")
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)
        # Shift the vertices so the bounding box starts at (0, 0).
        local_poly = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((int(maxx - minx) + 1, int(maxy - miny) + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(local_poly, grid)
        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self):
        """Open the image whose path was passed as ``preprocessed_paras``.

        :return: ``(dataset, rows, cols)`` from ``ImageHandler``.
        """
        tif_path = self.__preprocessed_paras
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Parse a ``POLYGON((x y, x y, ...))`` string into vertex lists.

        Each bracketed ring becomes a list of ``[x, y]`` float pairs.
        Coordinates are split on any whitespace, so both ``"1 2,3 4"`` and
        WKT-style ``"1 2, 3 4"`` are accepted; tokens without exactly two
        values are skipped.

        :param roiStr: polygon text from the CSV.
        :return: list of rings, each a list of ``[x, y]`` pairs.
        """
        content = roiStr.replace("POLYGON", "")
        # Each ')' closes a ring. Text outside brackets (e.g. the ',' between
        # rings) is carried into the next ring and discarded there as an
        # invalid coordinate token.
        rings = []
        current = ''
        for char in content:
            if char == '(':
                continue
            if char == ')':
                if current:
                    rings.append(current)
                current = ''
            else:
                current += char
        point_list = []
        for ring in rings:
            vertices = []
            for token in ring.split(','):
                cells = token.split()
                if len(cells) != 2:
                    continue
                vertices.append([float(cells[0]), float(cells[1])])
            point_list.append(vertices)
        return point_list