# -*- coding: UTF-8 -*-
"""
@Project : microproduct
@File    : csvHandle.py
@Function: read/write csv files
@Contact :
@Author  : SHJ
@Date    : 2022/11/6
@Version : 1.0.0
"""
|
import random
|
|||
|
import csv
|
|||
|
import logging
|
|||
|
import numpy as np
|
|||
|
from tool.algorithm.image.ImageHandle import ImageHandler
|
|||
|
from tool.algorithm.algtools.CoordinateTransformation import geo2imagexy
|
|||
|
from tool.algorithm.transforml1a.transHandle import TransImgL1A
|
|||
|
logger = logging.getLogger("mylog")
|
|||
|
|
|||
|
|
|||
|
class csvHandle:
    """Read measurement CSV files and convert their contents (measured
    points, ROI polygons) into image/raster coordinates and training data.
    """

    def __init__(self, row=0, col=0):
        """
        :param row: height of the optional ROI image buffer (0 disables it)
        :param col: width of the optional ROI image buffer (0 disables it)
        """
        self.imageHandler = ImageHandler()
        self.row = row
        self.col = col
        # NOTE(review): attribute keeps its original (misspelled) name
        # 'img_falg' for backward compatibility with external callers.
        self.img_falg = False
        if row != 0 and col != 0:
            self.roi_img = np.zeros((row, col), dtype=float)
            self.img_falg = True

    def get_roi_img(self):
        """Return the ROI image with unassigned (0) pixels replaced by NaN,
        or an empty array when no buffer was allocated."""
        if self.img_falg:
            self.roi_img[self.roi_img == 0] = np.nan
            return self.roi_img
        return np.array([])

    @staticmethod
    def readcsv(csv_path):
        """Read a CSV file and return all rows except the header line.

        :param csv_path: path of the CSV file
        :return: list of rows (each row a list of strings), header skipped
        """
        # fix: use a context manager so the file handle is always closed
        with open(csv_path, newline='') as f:
            csv_list = list(csv.reader(f))
        return csv_list[1:]

    def trans_measuredata(self, meas_data, tif_path):
        """Convert measured lon/lat points into image row/col coordinates.

        :param meas_data: rows from the measurement CSV
                          (col1=lon, col2=lat, col3=measured value)
        :param tif_path: reference GeoTIFF providing the geotransform
        :return: list of [row, col, value] for points inside the image
        """
        file_name = tif_path
        dataset = self.imageHandler.get_dataset(file_name)
        rows = self.imageHandler.get_img_height(file_name)
        cols = self.imageHandler.get_img_width(file_name)
        measdata_list = []
        logger.info('[MEASURE DATA]')
        for data in meas_data:
            lon = float(data[1])
            lat = float(data[2])
            coord = geo2imagexy(dataset, lon, lat)
            row = round(coord[1])
            col = round(coord[0])
            # fix: valid indices are 0..rows-1 / 0..cols-1; the original
            # `<=` comparisons admitted one out-of-range row/col.
            if 0 <= row < rows and 0 <= col < cols:
                measdata_list.append([row, col, float(data[3])])
                logger.info([row, col, float(data[3])])
            else:
                logger.warning("measure data: %s is beyond tif scope !", data)
        return measdata_list

    def write_roi_img_data(self, points, type_id):
        """Stamp `type_id` into the ROI image buffer at each (row, col) in
        `points`; out-of-bounds points are silently skipped."""
        if self.img_falg:
            for p in points:
                r = p[0]
                c = p[1]
                if r < self.row and c < self.col:
                    self.roi_img[r, c] = type_id

    def trans_landCover_measuredata(self, meas_data, cuted_ori_sim_path, max_train_num=100000):
        """
        Collect every point inside the measured ROI polygons as training data.

        :param meas_data: rows from the measurement CSV
                          (col1=class id, col2=class name, col3=POLYGON WKT)
        :param cuted_ori_sim_path: ori_sim image consumed by TransImgL1A
        :param max_train_num: cap on points kept per class (random sample)
        :return: list of [n, type_id, type_name, points]
        :raises Exception: on empty cells, a class with no points, or fewer
                           than two label types
        """
        type_data = {}
        n = 1
        train_data_list = []
        for data in meas_data:
            for d in data:
                if d == '':
                    raise Exception('there are empty data!', data)

            type_id = int(data[1])
            type_name = data[2]
            if type_id not in type_data:
                train_data_list.append([n, type_id, type_name, []])
                type_data.update({type_id: type_name})
                n += 1

            pointList = self.__roiPolygonAnalysis(data[3])
            for points in pointList:
                roi_poly = [(float(lon), float(lat)) for (lon, lat) in points]
                tr = TransImgL1A(cuted_ori_sim_path, roi_poly)
                if tr._mask is not None:
                    points = tr.get_roi_points()
                    for train_data in train_data_list:
                        if train_data[1] == type_id:
                            train_data[3] += points
                    self.write_roi_img_data(points, type_id)

        # fix: validate EVERY class, not just the leftover loop variable
        for train_data in train_data_list:
            if not train_data[3]:
                raise Exception('there are empty data!', train_data)
        if len(train_data_list) <= 1:
            raise Exception('there is only one label type!', train_data_list)

        for train_data in train_data_list:
            logger.info(str(train_data[0]) + "," + str(train_data[2]) + "," + "num:" + str(len(train_data[3])))
            max_num = max_train_num
            if len(train_data[3]) > max_num:
                logger.info("max number =" + str(max_num) + ", random select" + str(max_num) + " point as train data!")
                train_data[3] = random.sample(train_data[3], max_num)
        return train_data_list

    def trans_landCover_measuredata_dic(self, meas_data, cuted_ori_sim_path, max_train_num=100000):
        """Same as trans_landCover_measuredata but returns a dict keyed by
        'ids' / 'class_ids' / 'ch_names' / 'positions'."""
        train_data_list = self.trans_landCover_measuredata(meas_data, cuted_ori_sim_path, max_train_num)
        return self.trans_landCover_list2dic(train_data_list)

    @staticmethod
    def trans_landCover_list2dic(train_data_list):
        """Transpose [n, class_id, name, points] rows into a column dict."""
        ids = []
        class_ids = []
        ch_names = []
        positions = []
        for data in train_data_list:
            ids.append(data[0])
            class_ids.append(data[1])
            ch_names.append(data[2])
            positions.append(data[3])

        train_data_dic = {}
        train_data_dic.update({"ids": ids})
        train_data_dic.update({"class_ids": class_ids})
        train_data_dic.update({"ch_names": ch_names})
        train_data_dic.update({"positions": positions})
        return train_data_dic

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """
        Parse a CSV POLYGON (WKT-like) string into nested point lists.

        :param roiStr: e.g. "POLYGON((lon lat,lon lat,...))"
        :return: list of polygons, each a list of [lon, lat] float pairs
        """
        pointList = []
        strContent = roiStr.replace("POLYGON", "")
        # split the parenthesised contour string into ring substrings
        bracketsList = []
        strTemp = ''
        strList = []
        for c in strContent:
            if c == '(':
                bracketsList.append(c)
                continue
            elif c == ')':
                if len(bracketsList) > 0:
                    bracketsList.pop(0)
                if len(strTemp) > 0:
                    strList.append(strTemp)
                strTemp = ''
            else:
                strTemp += c
        for item in strList:
            if len(item) == 0:
                continue
            pTempList = item.split(',')
            pList = []
            for row in pTempList:
                cells = row.split(' ')
                if len(cells) != 2:
                    continue
                point = [float(cells[0]), float(cells[1])]
                pList.append(point)
            pointList.append(pList)
        return pointList

    def class_landcover_list(self, csv_path):
        """
        Return (class id -> class name, class id -> parent) mappings built
        from the first three columns of the land-cover class CSV.
        """
        # fix: close the file deterministically with a context manager
        with open(csv_path, newline='') as f:
            class_list = list(csv.reader(f))  # rows have four columns
        type_id_name = {}
        type_id_parent = {}
        for data in class_list[1:]:
            type_parent = data[0]
            type_id = int(data[1])
            type_name = data[2]

            if type_id not in type_id_name:
                type_id_name.update({type_id: type_name})
                type_id_parent.update({type_id: type_parent})
        return type_id_name, type_id_parent

    def trans_VegePhenology_measdata_dic(self, meas_data, cuted_ori_sim_path):
        """
        Split the measured polygon points into train and test sets.

        :param meas_data: rows from the measurement CSV
                          (col0=train/test, col1=SAR image name,
                           col2=phenology id, col3=phenology name,
                           col4=POLYGON WKT)
        :param cuted_ori_sim_path: ori_sim image consumed by TransImgL1A
        :return: (train_data, test_data, type_map)
        """
        train_data = []
        test_data = []
        type_data = {}

        for data in meas_data:
            data_use_type = data[0]
            sar_img_name = data[1]
            # fix: str.rstrip('.tar.gz') strips a *character set* and could
            # also eat trailing 'a'/'g'/'r'/'t'/'z'/'.' characters of the
            # base name; remove the exact suffix instead.
            if sar_img_name.endswith('.tar.gz'):
                name = sar_img_name[:-len('.tar.gz')]
            else:
                name = sar_img_name

            if data_use_type == 'train':
                phenology_id = int(data[2])
                phenology_name = data[3]
                if phenology_id not in type_data:
                    type_data.update({phenology_id: phenology_name})
            else:
                phenology_id = -1

            pointList = self.__roiPolygonAnalysis(data[4])
            l1a_points = []
            for points in pointList:
                roi_poly = [(float(lon), float(lat)) for (lon, lat) in points]
                tr = TransImgL1A(cuted_ori_sim_path, roi_poly)
                l1a_points = tr.get_roi_points()
                # l1a_points = tr.get_lonlat_points()
            if data_use_type == 'train':
                train_data.append([name, phenology_id, l1a_points, type_data[phenology_id]])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, l1a_points])

        type_map = []
        # avoid shadowing the builtin `id`; enumerate replaces zip(range(...))
        for idx, type_id in enumerate(type_data):
            type_map.append([idx + 1, type_id, type_data[type_id]])

        return train_data, test_data, type_map

    @staticmethod
    def vegePhenology_class_list(csv_path):
        """
        Return a class id -> class name mapping from columns 3 and 4 of the
        vegetation-phenology class CSV (blank ids are skipped).
        """
        # fix: close the file deterministically with a context manager
        with open(csv_path, newline='') as f:
            class_list = list(csv.reader(f))  # rows have four columns
        type_id_name = {}
        for data in class_list[1:]:
            type_id = data[2]
            type_name = data[3]

            if type_id not in type_id_name:
                if type_id.strip() != "":
                    type_id_name.update({type_id: type_name})
        return type_id_name
|
|||
|
|
|||
|
# if __name__ == '__main__':
#     csvh = csvHandle()
#     csv_path = r"I:\preprocessed\VegetationPhenologyMeasureData_E118.9_N31.4.csv"
#     data = csvh.trans_VegePhenology_measdata_dic(csvh.readcsv(csv_path), r"I:\preprocessed\GF3_SAY_QPSI_011444_E118.9_N31.4_20181012_L1A_AHV_L10003515422_RPC_ori_sim_preprocessed.tif")
#     pass
|