microproduct-l-sar/landcover-L-SAR/LandCoverAuxData.py

206 lines
6.8 KiB
Python
Raw Permalink Normal View History

2024-01-03 01:42:21 +00:00
# -*- coding: UTF-8 -*-
"""
@Project:LandCover
@File:LandCoverAuxData.py
@Function: Read land-cover auxiliary data (labelled sample polygons from a CSV file)
@Contact:
@Author:SHJ
@Date:2021/10/21 19:25
@Version:1.0.0
"""
import random
import csv
import numpy as np
import mahotas
import logging
from tool.algorithm.algtools.CoordinateTransformation import lonlat2geo, geo2imagexy
from tool.algorithm.image.ImageHandle import ImageHandler
# Module-level logger, looked up under the name "mylog".
logger = logging.getLogger("mylog")
class LandCoverMeasCsv:
    """Read land-cover label (measured sample) data from a CSV file.

    The CSV has one header row followed by data rows whose columns are read
    as: ``[0]`` parent type, ``[1]`` integer type id, ``[2]`` type name and
    ``[3]`` a ``POLYGON(...)`` string of ``lon lat`` pairs.
    """

    def __init__(self, csv_path, preprocessed_paras, max_tran__num_per_class):
        """
        :para csv_path: path of the label CSV file
        :para preprocessed_paras: path of the raster image the samples are
            mapped onto (handed to ImageHandler as a tif path)
        :para max_tran__num_per_class: maximum number of sample points kept
            for each class
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras
        self.__max_tran__num_per_class = max_tran__num_per_class

    def api_read_measure(self):
        """
        API entry point: read the CSV and convert it to per-class samples.
        :return: see ``__trans_measuredata``
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """
        Build the type lookups from the first three CSV columns.
        :return: ``(type_id_name, type_id_parent)``; both are dicts keyed by
            the integer type id. The first row seen for an id wins.
        """
        # Close the file deterministically instead of leaking the handle.
        with open(self.__csv_path, newline='') as csv_file:
            rows = list(csv.reader(csv_file))

        type_id_name = {}
        type_id_parent = {}
        for row in rows[1:]:  # rows[0] is the header
            type_parent = row[0]
            type_id = int(row[1])
            type_name = row[2]
            if type_id not in type_id_name:
                type_id_name[type_id] = type_name
                type_id_parent[type_id] = type_parent
        return type_id_name, type_id_parent

    @staticmethod
    def __readcsv(csv_path):
        """
        Read every data row of a CSV file.
        :para csv_path: path of the CSV file
        :return: list of rows (each a list of strings) without the header row
        """
        with open(csv_path, newline='') as csv_file:
            return list(csv.reader(csv_file))[1:]

    def __trans_measuredata(self, meas_data):
        """
        Collect all raster pixels inside each labelled polygon, per class.
        :para meas_data: data rows read from the CSV (header removed)
        :return: list of ``[index, type_id, type_name, points]`` entries in
            order of first appearance; ``index`` starts at 1 and ``points``
            holds the ``(row, col)`` pixels produced by ``__render``, capped
            at ``max_tran__num_per_class`` by random sampling
        :raises Exception: if a row has an empty cell, if a rendered polygon
            leaves its class without points, or if fewer than two classes
            are present
        """
        # The raster is the same for every row, so query it only once.
        dataset, rows, cols = self.__get_para_tif_inf()

        train_data_list = []
        entries_by_type = {}  # type_id -> its entry in train_data_list
        for data in meas_data:
            if any(cell == '' for cell in data):
                raise Exception('there are empty data!', data)

            type_id = int(data[1])
            type_name = data[2]
            entry = entries_by_type.get(type_id)
            if entry is None:
                entry = [len(train_data_list) + 1, type_id, type_name, []]
                train_data_list.append(entry)
                entries_by_type[type_id] = entry

            for points in self.__roiPolygonAnalysis(data[3]):
                poly = []
                for point in points:
                    lon = float(point[0])
                    lat = float(point[1])
                    projs = lonlat2geo(dataset, lon, lat)
                    # NOTE(review): the two projected values are passed in
                    # swapped order, as in the original code - confirm
                    # against the helpers' axis conventions.
                    coord = geo2imagexy(dataset, projs[1], projs[0])
                    # int() keeps the indices integral even if round()
                    # hands back a float-like scalar.
                    row = int(round(coord[1]))
                    col = int(round(coord[0]))
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("measure data: %s is beyond tif scope !", data)

                if poly:
                    entry[3].extend(self.__render(poly))
                    if not entry[3]:
                        raise Exception('there are empty data!', entry)

        if len(train_data_list) <= 1:
            raise Exception('there is only one label type!', train_data_list)

        max_num = self.__max_tran__num_per_class
        for train_data in train_data_list:
            logger.info("%s,%s,num:%s", train_data[0], train_data[2], len(train_data[3]))
            logger.info("max number =%s, random select%s point as train data!", max_num, max_num)
            if len(train_data[3]) > max_num:
                train_data[3] = random.sample(train_data[3], max_num)

        return train_data_list

    @staticmethod
    def __render(poly):
        # https://www.cnpython.com/qa/51516
        """Return polygon as grid of points inside polygon.

        Input : poly (list of ``[x, y]`` integer vertices)
        Output : output (list of ``(x, y)`` tuples in the input's coordinates)
        """
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)

        # Shift the polygon to the origin so the scratch grid only has to
        # cover its bounding box.
        shifted = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        grid = np.zeros((maxx - minx + 1, maxy - miny + 1), dtype=np.int8)
        mahotas.polygon.fill_polygon(shifted, grid)

        # Every non-zero cell is reported, shifted back to the original frame.
        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self):
        """
        Get the information of the raster image given at construction time.
        :return: ``(dataset, rows, cols)`` - the dataset object, image height
            and image width as reported by ImageHandler
        """
        tif_path = self.__preprocessed_paras
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """
        Convert the POLY column of the CSV into lists of points.
        :para roiStr: polygon string, e.g. ``POLYGON((x y,x y,...))``
        :return pointList: one list per parenthesised ring, each holding
            ``[float, float]`` points; pairs without exactly two values are
            skipped
        """
        content = roiStr.replace("POLYGON", "")

        # Cut the text into the pieces that end at each ')'; '(' is dropped.
        ring_texts = []
        current = ''
        for ch in content:
            if ch == '(':
                continue
            if ch == ')':
                if current:
                    ring_texts.append(current)
                    current = ''
            else:
                current += ch

        pointList = []
        for ring_text in ring_texts:
            ring = []
            for pair in ring_text.split(','):
                # split() tolerates the blank after a comma in standard WKT
                # ("x y, x y"); split(' ') produced an extra empty cell there
                # and the point was silently dropped.
                cells = pair.split()
                if len(cells) != 2:
                    continue
                ring.append([float(cells[0]), float(cells[1])])
            pointList.append(ring)
        return pointList