SIMOrthoProgram-Orth_LT1AB-.../Ortho/tool/csv/csvHandle.py

265 lines
9.4 KiB
Python
Raw Normal View History

# -*- coding: UTF-8 -*-
"""
@Project : microproduct
@File : csvHandle.py
@Function : 读写csv文件
@Contact :
@Author:SHJ
@Date:2022/11/6
@Version:1.0.0
"""
import random
import csv
import logging
import numpy as np
from tool.algorithm.image.ImageHandle import ImageHandler
from tool.algorithm.algtools.CoordinateTransformation import geo2imagexy
from tool.algorithm.transforml1a.transHandle import TransImgL1A
# Module logger. The "mylog" logger is shared by name; its handlers are presumably
# configured by the application entry point, not here (verify).
logger = logging.getLogger("mylog")
class csvHandle:
    """Read measured-data CSV files and convert them into image-space samples.

    When constructed with non-zero ``row`` and ``col``, the handler also keeps a
    ``(row, col)`` float raster (``roi_img``) into which land-cover sample points
    are written with their class id.
    """

    def __init__(self, row=0, col=0):
        """
        :param row: number of rows of the optional ROI raster (0 disables it)
        :param col: number of columns of the optional ROI raster (0 disables it)
        """
        self.imageHandler = ImageHandler()
        self.row = row
        self.col = col
        # Attribute name ("falg", sic) kept unchanged because other code may read it.
        self.img_falg = False
        if row != 0 and col != 0:
            self.roi_img = np.zeros((row, col), dtype=float)
            self.img_falg = True

    def get_roi_img(self):
        """Return the ROI raster with unlabelled (0) pixels replaced by NaN.

        The replacement is done in place on ``self.roi_img``. When the handler
        was created without a raster size, an empty array is returned instead.
        """
        if not self.img_falg:
            return np.array([])
        self.roi_img[self.roi_img == 0] = np.nan
        return self.roi_img

    @staticmethod
    def readcsv(csv_path):
        """Read a CSV file and return every row except the first (header) row.

        :param csv_path: path of the CSV file
        :return: list of rows, each a list of cell strings
        """
        # Context manager so the handle is closed deterministically.
        with open(csv_path, newline='') as csv_file:
            return list(csv.reader(csv_file))[1:]

    def trans_measuredata(self, meas_data, tif_path):
        """Convert measured (lon, lat, value) rows into image [row, col, value] samples.

        :param meas_data: rows as returned by ``readcsv``; columns 1, 2 and 3 are
            read as longitude, latitude and the measured value
        :param tif_path: raster whose georeferencing defines the pixel grid
        :return: list of [row, col, value] for the points that fall inside the image;
            points outside are logged as warnings and skipped
        """
        dataset = self.imageHandler.get_dataset(tif_path)
        rows = self.imageHandler.get_img_height(tif_path)
        cols = self.imageHandler.get_img_width(tif_path)
        measdata_list = []
        logger.info('[MEASURE DATA]')
        for data in meas_data:
            lon = float(data[1])
            lat = float(data[2])
            coord = geo2imagexy(dataset, lon, lat)
            row = round(coord[1])
            col = round(coord[0])
            # Valid indices are 0..rows-1 and 0..cols-1; the previous `<=` check
            # let one-past-the-end indices through.
            if 0 <= row < rows and 0 <= col < cols:
                measdata_list.append([row, col, float(data[3])])
                logger.info([row, col, float(data[3])])
            else:
                logger.warning("measure data: %s is beyond tif scope !", data)
        return measdata_list

    def write_roi_img_data(self, points, type_id):
        """Write ``type_id`` into the ROI raster at every (row, col) in ``points``.

        Points outside the raster are ignored. Does nothing when the handler has
        no ROI raster.
        """
        if not self.img_falg:
            return
        for p in points:
            r, c = p[0], p[1]
            # Negative indices must be rejected too: numpy would wrap them around
            # and silently label a pixel on the opposite edge.
            if 0 <= r < self.row and 0 <= c < self.col:
                self.roi_img[r, c] = type_id

    def trans_landCover_measuredata(self, meas_data, cuted_ori_sim_path, max_train_num=100000):
        """Collect the image points inside every land-cover sample polygon, grouped by class.

        :param meas_data: CSV rows as returned by ``readcsv``; column 1 is the class
            id, column 2 the class name and column 3 the POLYGON text
        :param cuted_ori_sim_path: raster path handed to ``TransImgL1A`` to map each
            polygon to image points
        :param max_train_num: maximum number of points kept per class; larger classes
            are randomly sub-sampled
        :return: list of [n, type_id, type_name, points], where n numbers the classes
            from 1 in order of first appearance
        :raises Exception: if a row contains an empty cell, if a polygon leaves its
            class without any points, or if fewer than two classes are present
        """
        train_data_list = []
        entry_by_type = {}  # type_id -> its entry in train_data_list
        for data in meas_data:
            if '' in data:
                raise Exception('there are empty data!', data)
            type_id = int(data[1])
            type_name = data[2]
            if type_id not in entry_by_type:
                entry = [len(train_data_list) + 1, type_id, type_name, []]
                train_data_list.append(entry)
                entry_by_type[type_id] = entry
            entry = entry_by_type[type_id]

            for ring in self.__roiPolygonAnalysis(data[3]):
                roi_poly = [(float(lon), float(lat)) for (lon, lat) in ring]
                tr = TransImgL1A(cuted_ori_sim_path, roi_poly)
                if tr._mask is None:
                    # NOTE(review): presumably the polygon does not overlap the image
                    # — such polygons contribute nothing; confirm in TransImgL1A.
                    continue
                roi_points = tr.get_roi_points()
                entry[3] += roi_points
                self.write_roi_img_data(roi_points, type_id)
                if not entry[3]:
                    raise Exception('there are empty data!', entry)

        if len(train_data_list) <= 1:
            raise Exception('there is only one label type!', train_data_list)

        for train_data in train_data_list:
            logger.info(str(train_data[0]) + "," + str(train_data[2]) + "," + "num:" + str(len(train_data[3])))
            if len(train_data[3]) > max_train_num:
                logger.info("max number =" + str(max_train_num) + ", random select " + str(max_train_num) + " point as train data!")
                train_data[3] = random.sample(train_data[3], max_train_num)
        return train_data_list

    def trans_landCover_measuredata_dic(self, meas_data, cuted_ori_sim_path, max_train_num=100000):
        """Same as ``trans_landCover_measuredata`` but returns the column-wise dict
        produced by ``trans_landCover_list2dic``."""
        train_data_list = self.trans_landCover_measuredata(meas_data, cuted_ori_sim_path, max_train_num)
        return self.trans_landCover_list2dic(train_data_list)

    @staticmethod
    def trans_landCover_list2dic(train_data_list):
        """Transpose [n, type_id, type_name, points] entries into a dict of lists.

        :return: dict with keys "ids", "class_ids", "ch_names" and "positions",
            each holding the corresponding column of ``train_data_list``
        """
        return {
            "ids": [data[0] for data in train_data_list],
            "class_ids": [data[1] for data in train_data_list],
            "ch_names": [data[2] for data in train_data_list],
            "positions": [data[3] for data in train_data_list],
        }

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Parse the coordinate rings of a WKT-like POLYGON string.

        Every innermost parenthesised group becomes one ring, and every
        comma-separated "x y" pair in it becomes a [float, float] point. Pairs
        that do not consist of exactly two whitespace-separated values are skipped.

        :param roiStr: polygon text, e.g. "POLYGON ((118.1 31.2, 118.3 31.2, 118.1 31.2))"
        :return: list of rings, each a list of [x, y] points
        """
        groups = []
        current = ''
        for ch in roiStr:
            if ch == '(':
                # Text before an opening bracket (the POLYGON keyword in any case,
                # the comma between rings) is not coordinate data.
                current = ''
            elif ch == ')':
                if current:
                    groups.append(current)
                current = ''
            else:
                current += ch

        point_list = []
        for group in groups:
            ring = []
            for pair in group.split(','):
                # split() instead of split(' '): standard WKT puts a space after each
                # comma, which split(' ') turned into an extra empty cell, causing
                # the vertex to be dropped.
                cells = pair.split()
                if len(cells) != 2:
                    continue
                ring.append([float(cells[0]), float(cells[1])])
            point_list.append(ring)
        return point_list

    def class_landcover_list(self, csv_path):
        """Read the land-cover class table from a CSV file (header row skipped).

        Column 0 is read as the parent class, column 1 as the integer class id and
        column 2 as the class name; the first row seen for each id wins.

        :return: (type_id_name, type_id_parent) dicts keyed by class id
        """
        type_id_name = {}
        type_id_parent = {}
        for data in self.readcsv(csv_path):
            type_parent = data[0]
            type_id = int(data[1])
            type_name = data[2]
            if type_id not in type_id_name:
                type_id_name[type_id] = type_name
                type_id_parent[type_id] = type_parent
        return type_id_name, type_id_parent

    @staticmethod
    def _strip_tar_gz(file_name):
        """Remove a trailing ".tar.gz" from ``file_name`` if present.

        ``str.rstrip('.tar.gz')`` (used previously) strips any trailing run of the
        characters '.', 't', 'a', 'r', 'g', 'z', so it also ate the end of names
        such as "scene_data.tar.gz".
        """
        suffix = '.tar.gz'
        if file_name.endswith(suffix):
            return file_name[:-len(suffix)]
        return file_name

    def trans_VegePhenology_measdata_dic(self, meas_data, cuted_ori_sim_path):
        """Convert vegetation-phenology sample rows into image points.

        Column 0 selects 'train' or 'test', column 1 is the SAR product file name,
        columns 2/3 the phenology id/name (read for 'train' rows only) and
        column 4 the POLYGON text.

        :return: (train_data, test_data, type_map) where train_data rows are
            [name, phenology_id, points, phenology_name], test_data rows are
            [name, -1, points], and type_map rows are [n, phenology_id, phenology_name]
            with n counting from 1 in order of first appearance
        """
        train_data = []
        test_data = []
        type_data = {}
        for data in meas_data:
            data_use_type = data[0]
            name = self._strip_tar_gz(data[1])
            if data_use_type == 'train':
                phenology_id = int(data[2])
                phenology_name = data[3]
                if phenology_id not in type_data:
                    type_data[phenology_id] = phenology_name
            else:
                phenology_id = -1

            l1a_points = []
            for ring in self.__roiPolygonAnalysis(data[4]):
                roi_poly = [(float(lon), float(lat)) for (lon, lat) in ring]
                tr = TransImgL1A(cuted_ori_sim_path, roi_poly)
                # NOTE(review): each ring overwrites l1a_points, so only the last
                # ring's points are kept for multi-ring polygons — confirm intended.
                l1a_points = tr.get_roi_points()

            if data_use_type == 'train':
                train_data.append([name, phenology_id, l1a_points, type_data[phenology_id]])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, l1a_points])

        type_map = [[n, type_id, type_name]
                    for n, (type_id, type_name) in enumerate(type_data.items(), start=1)]
        return train_data, test_data, type_map

    @staticmethod
    def vegePhenology_class_list(csv_path):
        """Read the phenology class table from a CSV file (header row skipped).

        Column 2 is the phenology id (kept as a string) and column 3 its name.
        Rows with a blank id are ignored; the first row seen for each id wins.

        :return: dict mapping id string to name
        """
        type_id_name = {}
        for data in csvHandle.readcsv(csv_path):
            type_id = data[2]
            type_name = data[3]
            if type_id not in type_id_name and type_id.strip() != "":
                type_id_name[type_id] = type_name
        return type_id_name
# if __name__ == '__main__':
# csvh = csvHandle()
# csv_path = r"I:\preprocessed\VegetationPhenologyMeasureData_E118.9_N31.4.csv"
# data = csvh.trans_VegePhenology_measdata_dic(csvh.readcsv(csv_path),r"I:\preprocessed\GF3_SAY_QPSI_011444_E118.9_N31.4_20181012_L1A_AHV_L10003515422_RPC_ori_sim_preprocessed.tif")
# pass