206 lines
6.3 KiB
Python
206 lines
6.3 KiB
Python
# -*- coding: UTF-8 -*-
|
|
"""
|
|
@Project:SalinityMain.py
|
|
@File:AuxData.py
|
|
@Function: 读取指标物候辅助数据
|
|
@Contact:
|
|
@Author:SHJ
|
|
@Date:2021/10/21 19:25
|
|
@Version:1.0.0
|
|
"""
|
|
import csv
|
|
import numpy as np
|
|
import mahotas
|
|
import logging
|
|
from tool.algorithm.algtools.CoordinateTransformation import lonlat2geo, geo2imagexy
|
|
from tool.algorithm.image.ImageHandle import ImageHandler
|
|
# Module-level logger looked up by the fixed name "mylog"; its handlers are
# presumably configured by the calling application (TODO confirm).
logger = logging.getLogger(name="mylog")
|
|
|
|
|
|
class PhenologyMeasCsv:
    """Read phenology field-measurement data from a CSV file.

    Judging by the column indexing below, every data row (after a header
    row) looks like ``[use_type, sar_archive_name, phenology_id,
    phenology_name, polygon_text]`` - TODO confirm against the CSV template.
    """

    def __init__(self, csv_path, preprocessed_paras):
        """
        :param csv_path: path of the measurement CSV file
        :param preprocessed_paras: mapping whose ``<image name>_HH`` entries
            hold the path of the preprocessed tif for that image
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras

    def api_read_phenology_measure(self):
        """Read the CSV and convert it into rasterised train/test samples.

        :return: ``(train_data, test_data, type_map)``; see
            ``__trans_measuredata`` for the layout of each element
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """Return the phenology classes listed in the CSV.

        Reads column 2 (type id) and column 3 (type name) of every data row
        (the header row is skipped). Blank ids are ignored, and for repeated
        ids the name of the first occurrence wins.

        :return: dict ``{type_id (str): type_name (str)}``
        """
        type_id_name = {}
        for data in self.__readcsv(self.__csv_path):
            type_id = data[2]
            if type_id.strip() != "" and type_id not in type_id_name:
                type_id_name[type_id] = data[3]
        return type_id_name

    @staticmethod
    def __readcsv(csv_path):
        """Read all data rows of a CSV file.

        :param csv_path: path of the CSV file
        :return: list of rows (lists of str), without the first (header)
            row; completely blank lines are dropped
        """
        # ``with`` guarantees the file handle is released.
        with open(csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))
        return [row for row in rows[1:] if row]

    @staticmethod
    def readcsv_dic(csv_path):
        """Read a CSV file into a column-oriented dict.

        The first non-blank row supplies the keys, and each key maps to the
        values of that column in all following non-blank rows.

        :param csv_path: path of the CSV file
        :return: dict ``{header: [value, ...]}``; empty for an empty file
        """
        with open(csv_path, newline='') as csv_file:
            rows = [row for row in csv.reader(csv_file) if row]
        if not rows:
            return {}
        keys, datas = rows[0], rows[1:]
        # Collect column n from every data row (not data row n).
        return {key: [data[n] for data in datas] for n, key in enumerate(keys)}

    @staticmethod
    def _strip_archive_suffix(sar_img_name):
        """Return *sar_img_name* without a trailing ``.tar.gz``.

        ``str.rstrip('.tar.gz')`` must not be used here: it strips any
        trailing run of the characters ``. t a r g z`` and would turn
        ``xxx_tag.tar.gz`` into ``xxx_``.
        """
        suffix = '.tar.gz'
        if sar_img_name.endswith(suffix):
            return sar_img_name[:-len(suffix)]
        return sar_img_name

    def __trans_measuredata(self, meas_data):
        """Rasterise every measured polygon into pixels of its image.

        The rows are split into a training set and a test set.

        :param meas_data: data rows read from the CSV (header excluded)
        :return: ``(train_data, test_data, type_map)`` where
            ``train_data`` / ``test_data`` are lists of
            ``[image_name, phenology_id, [(row, col), ...]]`` (rows whose use
            type is not ``'train'`` get ``phenology_id == -1``; rows that are
            neither ``'train'`` nor ``'test'`` are dropped), and ``type_map``
            is a list of ``[class_index, phenology_id, phenology_name]`` with
            1-based ``class_index`` in order of first appearance among the
            training rows.
        """
        train_data = []
        test_data = []
        type_data = {}

        for data in meas_data:
            data_use_type = data[0]
            name = self._strip_archive_suffix(data[1])
            dataset, rows, cols = self.__get_para_tif_inf(name)

            if data_use_type == 'train':
                phenology_id = int(data[2])
                # First name seen for an id wins; dict order keeps the order
                # in which classes first appear.
                type_data.setdefault(phenology_id, data[3])
            else:
                phenology_id = -1

            point_list = []
            for points in self.__roiPolygonAnalysis(data[4]):
                poly = []
                for point in points:
                    lon = float(point[0])
                    lat = float(point[1])
                    projs = lonlat2geo(dataset, lon, lat)
                    coord = geo2imagexy(dataset, projs[1], projs[0])
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("measure data: %s is beyond tif scope !", data)

                # A ring with no vertex inside the image has nothing to
                # rasterise (and __render cannot handle an empty list).
                if poly:
                    point_list.extend(self.__render(poly))

            if data_use_type == 'train':
                train_data.append([name, phenology_id, point_list])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, point_list])

        type_map = [[index, pid, pname]
                    for index, (pid, pname) in enumerate(type_data.items(), start=1)]
        return train_data, test_data, type_map

    @staticmethod
    def __render(poly):
        """Return the grid cells covered by a polygon.

        Input : poly - list of ``[row, col]`` integer vertices
        Output : list of ``(row, col)`` tuples of the cells that
            ``mahotas.polygon.fill_polygon`` fills for this polygon
        """
        # Adapted from https://www.cnpython.com/qa/51516
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)

        # Shift the polygon so its bounding box starts at (0, 0) and fill it
        # on a grid just large enough to hold it.
        local_poly = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((maxx - minx + 1, maxy - miny + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(local_poly, grid)

        # Translate the filled cells back to full-image coordinates.
        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self, tif_name):
        """Open the preprocessed ``<tif_name>_HH`` image and report its size.

        :param tif_name: image name (archive suffix already removed)
        :return: ``(dataset, rows, cols)`` as returned by ``ImageHandler``
        :raises KeyError: if ``<tif_name>_HH`` is missing from the
            preprocessed parameters
        """
        tif_path = self.__preprocessed_paras[tif_name + '_HH']
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Parse a WKT-like ``POLYGON((x y, x y, ...), (...))`` string.

        Every innermost parenthesised group becomes one ring. Each
        comma-separated item holding exactly two whitespace-separated
        numbers becomes a ``[float, float]`` point, and any other item is
        skipped.

        :param roiStr: polygon text from the CSV
        :return pointList: list of rings, each a list of ``[x, y]`` points
        """
        rings = []
        chars = []
        for c in roiStr.replace("POLYGON", ""):
            if c == '(':
                # Start of a group: discard any text that preceded it.
                chars = []
            elif c == ')':
                if chars:
                    rings.append(''.join(chars))
                chars = []
            else:
                chars.append(c)

        pointList = []
        for ring in rings:
            pList = []
            for item in ring.split(','):
                # split() with no argument tolerates the ", " separators and
                # padding spaces that standard WKT uses.
                cells = item.split()
                if len(cells) != 2:
                    continue
                pList.append([float(cells[0]), float(cells[1])])
            pointList.append(pList)
        return pointList