# 2024-05-17 06:19:22 +00:00
# -*- coding: UTF-8 -*-
"""
@Project : microproduct
@File : csvHandle . py
@Function : 读写csv文件
@Contact :
@Author : SHJ
@Date : 2022 / 11 / 6
@Version : 1.0 .0
"""
import random
import csv
import logging
import numpy as np
from tool . algorithm . image . ImageHandle import ImageHandler
from tool . algorithm . algtools . CoordinateTransformation import geo2imagexy
from tool . algorithm . transforml1a . transHandle import TransImgL1A
# Module-level logger used by csvHandle for its info and warning messages.
# NOTE(review): the logger name literal carries a leading and a trailing
# space -- confirm it matches the name the application configures.
logger = logging . getLogger ( " mylog " )
class csvHandle:
    """Read measured-data CSV files and map their geometries to image pixels.

    An instance optionally owns a ``row`` x ``col`` float raster
    (``roi_img``) into which class ids are painted while land-cover
    polygons are converted to pixel positions.
    """

    def __init__(self, row=0, col=0):
        """Create the handler and, when both sizes are non-zero, the ROI raster.

        :param row: number of rows of the ROI raster (0 disables the raster)
        :param col: number of columns of the ROI raster (0 disables the raster)
        """
        self.imageHandler = ImageHandler()
        self.row = row
        self.col = col
        # The attribute keeps its historical spelling ("falg") because it is
        # part of the instance state that other code may read.
        self.img_falg = False
        if row != 0 and col != 0:
            self.roi_img = np.zeros((row, col), dtype=float)
            self.img_falg = True

    def get_roi_img(self):
        """Return the ROI raster with every 0-valued pixel replaced by NaN.

        The replacement is done in place on ``self.roi_img``.  When no raster
        was allocated an empty array is returned instead.
        """
        if not self.img_falg:
            return np.array([])
        self.roi_img[self.roi_img == 0] = np.nan
        return self.roi_img

    @staticmethod
    def _read_csv_rows(csv_path):
        """Return every row of the CSV file, closing the file afterwards."""
        # newline='' is what the csv module requires so that it can handle
        # line endings embedded in quoted fields itself.
        with open(csv_path, newline='') as csv_file:
            return list(csv.reader(csv_file))

    @staticmethod
    def readcsv(csv_path):
        """Read a CSV file and return its rows without the header row.

        :param csv_path: path of the CSV file
        :return: list of rows, each row being a list of strings
        """
        return csvHandle._read_csv_rows(csv_path)[1:]

    def trans_measuredata(self, meas_data, tif_path):
        """Convert measured lon/lat samples to pixel positions of an image.

        :param meas_data: rows whose columns 1, 2 and 3 hold longitude,
            latitude and the measured value
        :param tif_path: image used for the coordinate conversion and for
            the bounds check
        :return: list of ``[row, col, value]`` for the samples that fall
            inside the image; samples outside are logged and dropped
        """
        dataset = self.imageHandler.get_dataset(tif_path)
        rows = self.imageHandler.get_img_height(tif_path)
        cols = self.imageHandler.get_img_width(tif_path)
        measdata_list = []
        logger.info(' [MEASURE DATA] ')
        for data in meas_data:
            lon = float(data[1])
            lat = float(data[2])
            coord = geo2imagexy(dataset, lon, lat)
            # coord[0] is used as the column and coord[1] as the row.
            row = round(coord[1])
            col = round(coord[0])
            # Valid indices are 0 .. size-1; row == rows or col == cols is
            # already one past the last pixel.
            if 0 <= row < rows and 0 <= col < cols:
                record = [row, col, float(data[3])]
                measdata_list.append(record)
                logger.info(record)
            else:
                logger.warning(" measure data: %s is beyond tif scope ! ", data)
        return measdata_list

    def write_roi_img_data(self, points, type_id):
        """Paint ``type_id`` into the ROI raster at every given pixel.

        Does nothing when no raster was allocated.  Points outside the
        raster are ignored.

        :param points: iterable of items whose first two entries are the
            row and column index
        :param type_id: value written at each point
        """
        if not self.img_falg:
            return
        for p in points:
            r = p[0]
            c = p[1]
            # The lower bound matters: a negative index would wrap around
            # and silently paint a pixel on the opposite edge.
            if 0 <= r < self.row and 0 <= c < self.col:
                self.roi_img[r, c] = type_id

    def trans_landCover_measuredata(self, meas_data, cuted_ori_sim_path, max_train_num=100000):
        """Collect, per land-cover class, the pixels covered by its polygons.

        :param meas_data: measured data read from CSV; column 1 is the class
            id, column 2 the class name and column 3 the polygon text
        :param cuted_ori_sim_path: path handed to ``TransImgL1A`` together
            with each polygon
        :param max_train_num: upper limit of points kept per class; larger
            sets are randomly sampled down to this size
        :return: list of ``[index, class_id, class_name, points]`` in order
            of first appearance, ``index`` starting at 1
        :raises Exception: if a row contains an empty cell, if a class has
            no points after a polygon was added, or if fewer than two
            classes are present
        """
        train_data_list = []
        entries_by_type = {}  # class id -> its entry in train_data_list
        for data in meas_data:
            for d in data:
                if d == '':
                    raise Exception(' there are empty data! ', data)

            type_id = int(data[1])
            type_name = data[2]
            entry = entries_by_type.get(type_id)
            if entry is None:
                entry = [len(train_data_list) + 1, type_id, type_name, []]
                train_data_list.append(entry)
                entries_by_type[type_id] = entry

            for polygon in self.__roiPolygonAnalysis(data[3]):
                roi_poly = [(float(lon), float(lat)) for (lon, lat) in polygon]
                tr = TransImgL1A(cuted_ori_sim_path, roi_poly)
                if tr._mask is None:
                    continue
                roi_points = tr.get_roi_points()
                entry[3] += roi_points
                self.write_roi_img_data(roi_points, type_id)
                # NOTE(review): this check is kept per polygon; the nesting
                # level in the original text was ambiguous -- confirm.
                if entry[3] == []:
                    raise Exception(' there are empty data! ', entry)

        if len(train_data_list) <= 1:
            raise Exception(' there is only one label type! ', train_data_list)

        max_num = max_train_num
        for train_data in train_data_list:
            logger.info(str(train_data[0]) + " , " + str(train_data[2]) + " , " + " num: " + str(len(train_data[3])))
            if len(train_data[3]) > max_num:
                logger.info(" max number = " + str(max_num) + " , random select " + str(max_num) + " point as train data! ")
                train_data[3] = random.sample(train_data[3], max_num)
        return train_data_list

    def trans_landCover_measuredata_dic(self, meas_data, cuted_ori_sim_path, max_train_num=100000):
        """Same as ``trans_landCover_measuredata`` but returned as a dict of columns."""
        train_data_list = self.trans_landCover_measuredata(meas_data, cuted_ori_sim_path, max_train_num)
        return self.trans_landCover_list2dic(train_data_list)

    @staticmethod
    def trans_landCover_list2dic(train_data_list):
        """Turn ``[index, class_id, name, points]`` entries into parallel lists.

        Entries whose point list is empty are skipped.

        :param train_data_list: list of 4-item entries
        :return: dict with the keys ``ids``, ``class_ids``, ``ch_names`` and
            ``positions``, each holding one list
        """
        ids = []
        class_ids = []
        ch_names = []
        positions = []
        for data in train_data_list:
            if data[3] == []:
                continue
            ids.append(data[0])
            class_ids.append(data[1])
            ch_names.append(data[2])
            positions.append(data[3])
        return {
            "ids": ids,
            "class_ids": class_ids,
            "ch_names": ch_names,
            "positions": positions,
        }

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """Convert the polygon text of a CSV cell into lists of points.

        Every parenthesised group that contains text becomes one ring.
        Vertices are separated by commas and their two numbers by any amount
        of whitespace, so both ``POLYGON((1 2,3 4))`` and
        ``POLYGON ((1 2, 3 4))`` are accepted.

        :param roiStr: polygon text
        :return: list of rings, each a list of ``[x, y]`` float pairs
        """
        strContent = roiStr.replace("POLYGON", "")

        # Split the text into the innermost parenthesised groups.
        rings = []
        buffer = []
        for ch in strContent:
            if ch == '(':
                # Anything collected before an opening bracket (keyword
                # residue, blanks, the comma between two rings) is not
                # coordinate text.
                buffer = []
            elif ch == ')':
                if buffer:
                    rings.append(''.join(buffer))
                    buffer = []
            else:
                buffer.append(ch)

        pointList = []
        for ring in rings:
            pList = []
            for vertex in ring.split(','):
                cells = vertex.split()
                if len(cells) != 2:
                    continue
                pList.append([float(cells[0]), float(cells[1])])
            if pList:
                pointList.append(pList)
        return pointList

    def class_landcover_list(self, csv_path):
        """Read the class table of a land-cover CSV file.

        Column 0 is the parent type, column 1 the integer class id and
        column 2 the class name; the header row is skipped.  The first
        occurrence of a class id wins.

        :param csv_path: path of the CSV file
        :return: ``(type_id_name, type_id_parent)`` dicts keyed by class id
        """
        type_id_name = {}
        type_id_parent = {}
        for data in self._read_csv_rows(csv_path)[1:]:
            type_parent = data[0]
            type_id = int(data[1])
            type_name = data[2]
            if type_id not in type_id_name:
                type_id_name[type_id] = type_name
                type_id_parent[type_id] = type_parent
        return type_id_name, type_id_parent

    def trans_VegePhenology_measdata_dic(self, meas_data, cuted_ori_sim_path):
        """Split vegetation-phenology samples into training and test data.

        :param meas_data: measured data read from CSV; column 0 is the usage
            (``train`` or ``test``), column 1 the SAR image archive name,
            columns 2 and 3 the phenology id and name, column 4 the polygon
        :param cuted_ori_sim_path: path handed to ``TransImgL1A`` together
            with each polygon
        :return: ``(train_data, test_data, type_map)`` where train entries
            are ``[name, id, points, phenology_name]``, test entries are
            ``[name, -1, points]`` and ``type_map`` lists
            ``[index, id, phenology_name]`` with ``index`` starting at 1
        """
        archive_suffix = '.tar.gz'
        train_data = []
        test_data = []
        type_data = {}
        for data in meas_data:
            data_use_type = data[0]
            sar_img_name = data[1]
            # str.rstrip() strips a *set of characters*, which would also eat
            # trailing letters of the image name; remove the suffix only.
            if sar_img_name.endswith(archive_suffix):
                name = sar_img_name[:-len(archive_suffix)]
            else:
                name = sar_img_name

            if data_use_type == 'train':
                phenology_id = int(data[2])
                phenology_name = data[3]
                if phenology_id not in type_data:
                    type_data[phenology_id] = phenology_name
            else:
                phenology_id = -1

            l1a_points = []
            for polygon in self.__roiPolygonAnalysis(data[4]):
                roi_poly = [(float(lon), float(lat)) for (lon, lat) in polygon]
                tr = TransImgL1A(cuted_ori_sim_path, roi_poly)
                # NOTE(review): only the points of the last polygon of a row
                # are kept -- confirm rows never carry more than one ring.
                l1a_points = tr.get_roi_points()

            if data_use_type == 'train':
                train_data.append([name, phenology_id, l1a_points, type_data[phenology_id]])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, l1a_points])

        type_map = [
            [index, phenology_id, phenology_name]
            for index, (phenology_id, phenology_name) in enumerate(type_data.items(), start=1)
        ]
        return train_data, test_data, type_map

    @staticmethod
    def vegePhenology_class_list(csv_path):
        """Read the phenology id -> name table of a CSV file.

        Column 2 is the id (kept as a string) and column 3 the name; the
        header row and rows with a blank id are skipped.  The first
        occurrence of an id wins.

        :param csv_path: path of the CSV file
        :return: dict mapping id string to name
        """
        type_id_name = {}
        for data in csvHandle._read_csv_rows(csv_path)[1:]:
            type_id = data[2]
            type_name = data[3]
            if type_id not in type_id_name and type_id.strip() != "":
                type_id_name[type_id] = type_name
        return type_id_name
# if __name__ == '__main__':
# csvh = csvHandle()
# csv_path = r"I:\preprocessed\VegetationPhenologyMeasureData_E118.9_N31.4.csv"
# data = csvh.trans_VegePhenology_measdata_dic(csvh.readcsv(csv_path),r"I:\preprocessed\GF3_SAY_QPSI_011444_E118.9_N31.4_20181012_L1A_AHV_L10003515422_RPC_ori_sim_preprocessed.tif")
# pass