# microproduct/vegetationPhenology/VegetationPhenologyAuxData.py
#
# 206 lines
# 6.3 KiB
# Python
# Raw Permalink Normal View History
#
# 2023-08-28 10:17:29 +00:00
# -*- coding: UTF-8 -*-
"""
@Project:SalinityMain.py
@File:AuxData.py
@Function: Read the auxiliary (field-measured) phenology data
@Contact:
@Author:SHJ
@Date:2021/10/21 19:25
@Version:1.0.0
"""
import csv
import numpy as np
import mahotas
import logging
from tool.algorithm.algtools.CoordinateTransformation import lonlat2geo, geo2imagexy
from tool.algorithm.image.ImageHandle import ImageHandler
logger = logging.getLogger("mylog")
class PhenologyMeasCsv:
    """Read the phenology field-measurement CSV and rasterise its polygons.

    Column layout, as consumed by the code below (the first CSV row is a
    header and is skipped):

    * column 0 -- usage type, ``'train'`` or ``'test'``
    * column 1 -- SAR image name, optionally carrying a ``.tar.gz`` suffix
    * column 2 -- phenology type id
    * column 3 -- phenology type name
    * column 4 -- region of interest as a ``POLYGON((lon lat,lon lat,...))``
      string
    """

    # Archive extension that SAR image names in the CSV may carry.
    _ARCHIVE_SUFFIX = '.tar.gz'

    def __init__(self, csv_path, preprocessed_paras):
        """
        :para csv_path: path of the measurement csv file
        :para preprocessed_paras: mapping indexed with ``<image name>_HH``
            to obtain the path of the matching tif
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras

    def api_read_phenology_measure(self):
        """
        Read the csv and convert every polygon into image pixels.
        :return: (train_data, test_data, type_map), see __trans_measuredata
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """
        Collect the phenology types declared in the csv.
        :return: dict mapping type id (column 2, kept as the raw string) to
            type name (column 3); ids that are blank after stripping are
            ignored and the first name seen for an id wins
        """
        type_id_name = {}
        for data in self._read_rows(self.__csv_path)[1:]:
            if not data:
                # csv.reader yields [] for blank lines; nothing to read there
                continue
            type_id = data[2]
            type_name = data[3]
            if type_id not in type_id_name and type_id.strip() != "":
                type_id_name[type_id] = type_name
        return type_id_name

    @staticmethod
    def _read_rows(csv_path):
        """
        Read every row of a csv file, closing the file afterwards.
        :para csv_path: path of the csv file
        :return: list of rows, each row a list of strings
        """
        with open(csv_path, newline='') as csv_file:
            return list(csv.reader(csv_file))

    @staticmethod
    def _strip_archive_suffix(sar_img_name):
        """
        Remove one trailing '.tar.gz' from an image name.

        str.rstrip('.tar.gz') must not be used for this: it strips any run of
        the characters '.', 't', 'a', 'r', 'g', 'z' and therefore also eats
        the end of the real name (e.g. 'data.tar.gz' would become 'd').
        :para sar_img_name: image name as written in the csv
        :return: the name without the archive suffix
        """
        suffix = PhenologyMeasCsv._ARCHIVE_SUFFIX
        if sar_img_name.endswith(suffix):
            return sar_img_name[:-len(suffix)]
        return sar_img_name

    @staticmethod
    def __readcsv(csv_path):
        """
        Read the csv data rows.
        :para csv_path: path of the csv file
        :return: all rows except the header row
        """
        return PhenologyMeasCsv._read_rows(csv_path)[1:]

    @staticmethod
    def readcsv_dic(csv_path):
        """
        Read the csv into a dict keyed by the header cells.
        :para csv_path: path of the csv file
        :return: dict mapping the n-th header cell to a copy of the n-th
            data row
        """
        csv_list = PhenologyMeasCsv._read_rows(csv_path)
        keys = csv_list[0]
        datas = csv_list[1:]
        data_dic = {}
        # NOTE(review): header n is paired with data ROW n, not with column n,
        # and an IndexError is raised when there are fewer data rows than
        # header cells. This looks like a transposition slip, but the
        # behaviour is kept unchanged because callers may rely on it --
        # confirm the intent.
        for n, key in enumerate(keys):
            data_dic[key] = list(datas[n])
        return data_dic

    def __trans_measuredata(self, meas_data):
        """
        Convert every measured polygon into the image pixels it covers and
        split the records into a training set and a test set.
        :para meas_data: measurement rows read from the csv (header removed)
        :return: (train_data, test_data, type_map) where the first two are
            lists of [image name, phenology id, pixel list] (the id is -1 for
            rows that are not 'train') and type_map is a list of
            [1-based index, phenology id, phenology name] for the types seen
            in the training rows
        """
        train_data = []
        test_data = []
        type_data = {}
        for data in meas_data:
            if not data:
                # blank csv line
                continue
            point_list = []
            data_use_type = data[0]
            name = self._strip_archive_suffix(data[1])
            dataset, rows, cols = self.__get_para_tif_inf(name)
            if data_use_type == 'train':
                phenology_id = int(data[2])
                if phenology_id not in type_data:
                    type_data[phenology_id] = data[3]
            else:
                phenology_id = -1

            for points in self.__roiPolygonAnalysis(data[4]):
                poly = []
                for point in points:
                    lon = float(point[0])
                    lat = float(point[1])
                    projs = lonlat2geo(dataset, lon, lat)
                    coord = geo2imagexy(dataset, projs[1], projs[0])
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("measure data: %s is beyond tif scope !", data)
                # __render returns [] when no vertex survived the bounds check
                point_list.extend(self.__render(poly))

            if data_use_type == 'train':
                train_data.append([name, phenology_id, point_list])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, point_list])

        type_map = [
            [index, type_id, type_name]
            for index, (type_id, type_name) in enumerate(type_data.items(), start=1)
        ]
        return train_data, test_data, type_map

    @staticmethod
    def __render(poly):
        # https://www.cnpython.com/qa/51516
        """Return polygon as grid of points inside polygon.
        Input : poly (list of [row, col] vertices)
        Output : output (list of (row, col) tuples); empty when poly is empty
        """
        if not poly:
            # zip(*[]) cannot be unpacked into xs, ys; an empty polygon
            # simply covers no pixel
            return []
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)
        # shift the polygon to the origin so the fill grid only has to span
        # its bounding box
        shifted = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((maxx - minx + 1, maxy - miny + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(shifted, grid)
        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self, tif_name):
        """
        Get the information of an image.
        :para tif_name: image name (without the '_HH' key suffix)
        :return: (dataset, rows, cols) of the tif registered as tif_name + '_HH'
        """
        tif_path = self.__preprocessed_paras[tif_name + '_HH']
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """
        Convert the POLYGON text of the csv into nested lists.
        :para roiStr: polygon text, e.g. 'POLYGON((lon lat,lon lat,...))'
        :return pointList: one list per ring, each a list of
            [float, float] points in the order they appear in the text
        """
        strContent = roiStr.replace("POLYGON", "")
        # Split the text into the chunks enclosed by brackets.
        depth = 0  # number of currently open '('
        strTemp = ''
        strList = []
        for c in strContent:
            if c == '(':
                depth += 1
            elif c == ')':
                # an unmatched ')' is ignored
                if depth > 0:
                    depth -= 1
                    if strTemp:
                        strList.append(strTemp)
                        strTemp = ''
            else:
                strTemp += c

        pointList = []
        for item in strList:
            pList = []
            for row in item.split(','):
                # Split on any whitespace so that the usual WKT form with a
                # blank after each comma ('1 2, 3 4') is accepted as well.
                cells = row.split()
                if len(cells) != 2:
                    continue
                pList.append([float(cells[0]), float(cells[1])])
            pointList.append(pList)
        return pointList