412 lines
13 KiB
Python
412 lines
13 KiB
Python
# -*- coding: UTF-8 -*-
|
|
"""
|
|
@Project:SalinityMain.py
|
|
@File:AuxData.py
|
|
@Function: 读取指标物候辅助数据
|
|
@Contact:
|
|
@Author:SHJ
|
|
@Date:2021/10/21 19:25
|
|
@Version:1.0.0
|
|
"""
|
|
import csv
|
|
import numpy as np
|
|
import mahotas
|
|
import logging
|
|
import random
|
|
from tool.algorithm.algtools.CoordinateTransformation import lonlat2geo, geo2imagexy
|
|
from tool.algorithm.image.ImageHandle import ImageHandler
|
|
# Module-wide logger, fetched by the shared name "mylog" so that handlers set up
# on that name elsewhere also apply here.
logger = logging.getLogger(name="mylog")
|
|
|
|
|
|
class PhenologyMeasCsv:
    """Read measured phenology samples from a CSV file.

    Data rows are indexed as
    ``[use_type, sar_image_name, type_id, type_name, polygon_wkt]``
    where ``use_type`` is ``'train'`` or ``'test'`` and ``polygon_wkt`` is a
    ``POLYGON((lon lat, ...))`` string. The first CSV line is a header.
    """

    def __init__(self, csv_path, preprocessed_paras):
        """
        :param csv_path: path of the measured-data CSV file
        :param preprocessed_paras: mapping looked up with ``<image name> + '_HH'``
            to obtain the tif path of each SAR image
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras

    def api_read_phenology_measure(self):
        """Read the CSV file and convert it into train/test pixel samples.

        :return: ``(train_data, test_data, type_map)``, see ``__trans_measuredata``
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """Return ``{type_id: type_name}`` taken from CSV columns 2 and 3.

        The header line is skipped, rows with a blank type id (or too few
        columns) are ignored, and the first name seen for an id wins.
        Ids are returned as the raw strings found in the file.
        """
        with open(self.__csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))

        type_id_name = {}
        for data in rows[1:]:
            if len(data) < 4:
                # Blank or truncated line: there is no id/name pair to read.
                continue
            type_id = data[2]
            if type_id.strip() != "" and type_id not in type_id_name:
                type_id_name[type_id] = data[3]
        return type_id_name

    @staticmethod
    def __readcsv(csv_path):
        """Return the data rows of a CSV file (header line and blank lines dropped).

        :param csv_path: path of the CSV file
        """
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        return [row for row in rows[1:] if row]

    @staticmethod
    def readcsv_dic(csv_path):
        """Read a CSV file into a ``{column_header: [values of that column]}`` dict.

        Blank lines are ignored; an empty file yields ``{}``.
        (Previously the n-th header was paired with the n-th *row*.)

        :param csv_path: path of the CSV file
        """
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        if not rows:
            return {}
        keys = rows[0]
        datas = [row for row in rows[1:] if row]
        return {key: [row[n] for row in datas] for n, key in enumerate(keys)}

    @staticmethod
    def __image_name(sar_img_name):
        """Return *sar_img_name* with one trailing ``.tar.gz`` removed.

        ``str.rstrip('.tar.gz')`` strips any trailing run of the characters
        ``. t a r g z`` and therefore also truncated names such as
        ``..._data.tar.gz``; this removes the exact suffix only.
        """
        suffix = '.tar.gz'
        if sar_img_name.endswith(suffix):
            return sar_img_name[:-len(suffix)]
        return sar_img_name

    def __trans_measuredata(self, meas_data):
        """Rasterise every measured polygon and split the samples into train/test sets.

        :param meas_data: CSV data rows (see the class docstring for the layout)
        :return: ``(train_data, test_data, type_map)`` where
            ``train_data``/``test_data`` are lists of
            ``[image_name, phenology_id, [(row, col), ...]]`` (test rows use id -1)
            and ``type_map`` is ``[[n, type_id, type_name], ...]`` with ``n``
            starting at 1 in order of first appearance of each training type.
            Rows whose use type is neither ``'train'`` nor ``'test'`` are dropped.
        """
        train_data = []
        test_data = []
        type_data = {}

        for data in meas_data:
            data_use_type = data[0]
            name = self.__image_name(data[1])
            dataset, rows, cols = self.__get_para_tif_inf(name)

            if data_use_type == 'train':
                phenology_id = int(data[2])
                # Keep the first name registered for each id.
                type_data.setdefault(phenology_id, data[3])
            else:
                phenology_id = -1

            point_list = []
            for points in self.__roiPolygonAnalysis(data[4]):
                poly = []
                for lon, lat in points:
                    projs = lonlat2geo(dataset, float(lon), float(lat))
                    coord = geo2imagexy(dataset, projs[1], projs[0])
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("measure data: %s is beyond tif scope !", data)
                # extend() instead of repeated list concatenation (quadratic).
                point_list.extend(self.__render(poly))

            if data_use_type == 'train':
                train_data.append([name, phenology_id, point_list])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, point_list])

        type_map = [[n, type_id, type_name]
                    for n, (type_id, type_name) in enumerate(type_data.items(), start=1)]
        return train_data, test_data, type_map

    @staticmethod
    def __render(poly):
        """Return all grid points inside a polygon (vertices included).

        Based on https://www.cnpython.com/qa/51516

        :param poly: list of ``[row, col]`` integer vertices; may be empty when
            every vertex fell outside the image, in which case ``[]`` is returned
        :return: list of ``(row, col)`` tuples covered by the filled polygon
        """
        if not poly:
            return []
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)

        # Shift the polygon to a local origin so the fill grid is only as big as its bbox.
        new_poly = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((maxx - minx + 1, maxy - miny + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(new_poly, grid)

        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self, tif_name):
        """Return ``(dataset, rows, cols)`` of the HH tif of an image.

        :param tif_name: image name; the tif path is
            ``preprocessed_paras[tif_name + '_HH']``
        """
        tif_path = self.__preprocessed_paras[tif_name + '_HH']
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Parse a CSV ``POLYGON`` string into a list of rings of ``[lon, lat]`` floats.

        Each parenthesised ring becomes one list of points. Coordinate pairs may
        be separated by ``","`` or ``", "``; pairs that do not consist of exactly
        two whitespace-separated numbers are skipped.

        :param roiStr: polygon string, e.g. ``"POLYGON((lon lat,lon lat,...))"``
        :return: list of rings, each a list of ``[lon, lat]``
        """
        str_content = roiStr.replace("POLYGON", "")

        # Collect the text between brackets; each ')' closes one ring.
        ring_texts = []
        current = ''
        for c in str_content:
            if c == '(':
                continue
            if c == ')':
                if current:
                    ring_texts.append(current)
                    current = ''
            else:
                current += c

        point_list = []
        for item in ring_texts:
            ring = []
            for pair in item.split(','):
                # split() without arguments tolerates the space WKT puts after commas.
                cells = pair.split()
                if len(cells) != 2:
                    continue
                ring.append([float(cells[0]), float(cells[1])])
            point_list.append(ring)
        return point_list
|
|
|
|
|
|
|
|
class PhenoloyMeasCsv_geo:
    """Read measured class samples (polygons in image geo-coordinates) from a CSV file.

    Data rows are indexed as ``[parent, type_id, type_name, polygon_wkt]``.
    All polygons are rasterised against a single tif given at construction.
    """

    def __init__(self, csv_path, preprocessed_paras, max_tran__num_per_class=100000):
        """
        :param csv_path: path of the measured-data CSV file
        :param preprocessed_paras: path of the tif the polygons are projected onto
        :param max_tran__num_per_class: stored but currently not used; classes are
            balanced to the size of the smallest non-empty class instead
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras
        self.__max_tran__num_per_class = max_tran__num_per_class

    def api_read_measure(self):
        """Read all CSV data rows and convert them into per-class training pixels."""
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def api_read_measure_by_name(self, name):
        """Like ``api_read_measure`` but only for rows whose first column contains *name*."""
        csv_data = self.__readcsv_by_name(self.__csv_path, name)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """Return ``(type_id_name, type_id_parent)`` read from the CSV.

        Both dicts are keyed by ``int(column 1)``; the first row seen for an id
        supplies its name (column 2) and parent (column 0). The header line is
        skipped and blank/truncated lines are ignored.
        """
        with open(self.__csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))

        type_id_name = {}
        type_id_parent = {}
        for data in rows[1:]:
            if len(data) < 3:
                continue
            type_id = int(data[1])
            if type_id not in type_id_name:
                type_id_name[type_id] = data[2]
                type_id_parent[type_id] = data[0]
        return type_id_name, type_id_parent

    @staticmethod
    def __readcsv(csv_path):
        """Return the data rows of a CSV file (header line and blank lines dropped).

        :param csv_path: path of the CSV file
        """
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        return [row for row in rows[1:] if row]

    @staticmethod
    def __readcsv_by_name(csv_path, name):
        """Return every non-blank CSV row whose first column contains *name*.

        The header is not skipped explicitly; it is returned only if its first
        column happens to contain *name*.

        :param csv_path: path of the CSV file
        :param name: substring searched for in column 0
        """
        with open(csv_path, newline='') as csv_file:
            return [row for row in csv.reader(csv_file) if row and name in row[0]]

    def __trans_measuredata(self, meas_data):
        """Rasterise all polygons and group the pixels by class, balancing class sizes.

        :param meas_data: CSV data rows ``[parent, type_id, type_name, polygon_wkt]``
        :return: list of ``[n, type_id, type_name, [(row, col), ...]]`` with ``n``
            numbered from 1 in order of first appearance. Every class with more
            pixels than the smallest non-empty class is randomly down-sampled
            (``random.sample``) to that size.
        :raises Exception: if a row has an empty cell, if a polygon yields no
            pixels for its class, or if fewer than two classes are present
        :raises ValueError: if no class received any pixel at all
        """
        train_data_list = []
        entries = {}  # type_id -> its entry in train_data_list
        tif_info = None  # the tif does not depend on the row, so open it once

        for data in meas_data:
            if '' in data:
                raise Exception('there are empty data!', data)

            if tif_info is None:
                tif_info = self.__get_para_tif_inf()
            dataset, rows, cols = tif_info

            type_id = int(data[1])
            type_name = data[2]
            if type_id not in entries:
                entry = [len(train_data_list) + 1, type_id, type_name, []]
                train_data_list.append(entry)
                entries[type_id] = entry

            for points in self.__roiPolygonAnalysis(data[3]):
                poly = []
                for point in points:
                    lon = float(point[0])
                    lat = float(point[1])
                    coord = geo2imagexy(dataset, lon, lat)
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("point %s is beyond tif scope, in measure data: %s !", point, data)
                if poly:
                    entry = entries[type_id]
                    # Render once and append in place (previously rendered twice).
                    entry[3].extend(self.__render(poly))
                    if not entry[3]:
                        raise Exception('there are empty data!', entry)

        if len(train_data_list) <= 1:
            raise Exception('there is only one label type!', train_data_list)

        num_list = [len(train_data[3]) for train_data in train_data_list if train_data[3]]
        if not num_list:
            raise ValueError('there are empty data!', train_data_list)
        # Despite the name this is the size of the smallest non-empty class; the
        # configured self.__max_tran__num_per_class cap is intentionally not applied.
        max_num = min(num_list)
        for train_data in train_data_list:
            logger.info("%s,%s,num:%s", train_data[0], train_data[2], len(train_data[3]))
            logger.info("max number =%s, random select%s point as train data!", max_num, max_num)
            if len(train_data[3]) > max_num:
                train_data[3] = random.sample(train_data[3], max_num)

        return train_data_list

    @staticmethod
    def __render(poly):
        """Return all grid points inside a polygon (vertices included).

        Based on https://www.cnpython.com/qa/51516

        :param poly: list of ``[row, col]`` integer vertices; ``[]`` yields ``[]``
        :return: list of ``(row, col)`` tuples covered by the filled polygon
        """
        if not poly:
            return []
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)

        # Shift the polygon to a local origin so the fill grid is only as big as its bbox.
        new_poly = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((maxx - minx + 1, maxy - miny + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(new_poly, grid)

        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self):
        """Return ``(dataset, rows, cols)`` of the tif passed as ``preprocessed_paras``."""
        tif_path = self.__preprocessed_paras
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Parse a CSV ``POLYGON`` string into a list of rings of ``[x, y]`` floats.

        Each parenthesised ring becomes one list of points. Coordinate pairs may
        be separated by ``","`` or ``", "``; pairs that do not consist of exactly
        two whitespace-separated numbers are skipped.

        :param roiStr: polygon string, e.g. ``"POLYGON((x y,x y,...))"``
        :return: list of rings, each a list of ``[x, y]``
        """
        str_content = roiStr.replace("POLYGON", "")

        # Collect the text between brackets; each ')' closes one ring.
        ring_texts = []
        current = ''
        for c in str_content:
            if c == '(':
                continue
            if c == ')':
                if current:
                    ring_texts.append(current)
                    current = ''
            else:
                current += c

        point_list = []
        for item in ring_texts:
            ring = []
            for pair in item.split(','):
                # split() without arguments tolerates the space WKT puts after commas.
                cells = pair.split()
                if len(cells) != 2:
                    continue
                ring.append([float(cells[0]), float(cells[1])])
            point_list.append(ring)
        return point_list