# -*- coding: UTF-8 -*-
"""
@Project:SalinityMain.py
@File:AuxData.py
@Function: Read the auxiliary (measured) phenology indicator data
@Contact:
@Author:SHJ
@Date:2021/10/21 19:25
@Version:1.0.0
"""
import csv
import numpy as np
import mahotas
import logging
from tool.algorithm.algtools.CoordinateTransformation import lonlat2geo, geo2imagexy
from tool.algorithm.image.ImageHandle import ImageHandler
logger = logging.getLogger("mylog")


class PhenologyMeasCsv:
    """Reader for the phenology measurement CSV.

    The methods below index each data row as follows:
    column 0 is the usage type ('train' / 'test'), column 1 the SAR image
    archive name, column 2 the phenology type id, column 3 the phenology type
    name and column 4 a ``POLYGON(...)`` string with lon/lat vertices.
    The first CSV row is treated as a header.
    """

    def __init__(self, csv_path, preprocessed_paras):
        """
        :para csv_path: path of the measurement csv file
        :para preprocessed_paras: mapping looked up with ``<image name>_HH``
            to obtain the path of the matching tif image
        """
        self.__csv_path = csv_path
        self.__preprocessed_paras = preprocessed_paras

    def api_read_phenology_measure(self):
        """
        Read the csv file and convert it to training / test samples.
        :return: (train_data, test_data, type_map), see __trans_measuredata
        """
        csv_data = self.__readcsv(self.__csv_path)
        return self.__trans_measuredata(csv_data)

    def class_list(self):
        """
        Build the phenology type lookup from the csv file.
        :return: dict mapping type id (column 2, as read from the file) to
            type name (column 3); rows with a blank id are ignored and the
            first occurrence of an id wins
        """
        type_id_name = {}
        for data in self.__read_rows(self.__csv_path)[1:]:
            if len(data) < 4:
                # blank or truncated line: no id/name pair to record
                continue
            type_id = data[2]
            type_name = data[3]
            if type_id not in type_id_name and type_id.strip() != "":
                type_id_name[type_id] = type_name
        return type_id_name

    @staticmethod
    def __read_rows(csv_path):
        """
        Read every row of a csv file (header included).
        :para csv_path: csv file path
        :return: list of rows, each row being a list of strings
        """
        # the context manager guarantees the file handle is released
        with open(csv_path, newline='') as csv_file:
            return list(csv.reader(csv_file))

    @staticmethod
    def __readcsv(csv_path):
        """
        Read the csv table data.
        :para csv_path: csv file path
        :return: list of data rows, the header row is dropped
        """
        return PhenologyMeasCsv.__read_rows(csv_path)[1:]

    @staticmethod
    def readcsv_dic(csv_path):
        """
        Read the csv table data as a dictionary of columns.
        :para csv_path: csv file path
        :return: dict mapping each header name to the list of values found in
            that column; an empty dict when the file has no rows
        :raises IndexError: when a non-blank data row has fewer fields than
            the header
        """
        rows = PhenologyMeasCsv.__read_rows(csv_path)
        if not rows:
            return {}
        keys = rows[0]
        # csv.reader yields [] for blank lines; they carry no column values
        datas = [row for row in rows[1:] if row]
        # NOTE: values are gathered per column; the former ``datas[n]`` paired
        # the n-th header with the n-th data *row*.
        return {key: [row[n] for row in datas] for n, key in enumerate(keys)}

    def __trans_measuredata(self, meas_data):
        """
        Collect all pixels inside every measured polygon, split into
        training data and test data.
        :para meas_data: measured data rows read from the csv file
        :return: (train_data, test_data, type_map) where train_data/test_data
            are lists of [image name, phenology id, pixel list] and type_map
            is a list of [1-based index, phenology id, phenology name] for
            the ids seen in training rows
        """
        archive_suffix = '.tar.gz'
        train_data = []
        test_data = []
        type_data = {}
        for data in meas_data:
            if not data:
                # blank csv line
                continue
            point_list = []
            data_use_type = data[0]
            sar_img_name = data[1]
            # str.rstrip('.tar.gz') strips a *set of characters* and would
            # also eat trailing 'a', 'r', 't', 'g', 'z', '.' of the base name,
            # so remove the suffix explicitly.
            if sar_img_name.endswith(archive_suffix):
                name = sar_img_name[:-len(archive_suffix)]
            else:
                name = sar_img_name
            dataset, rows, cols = self.__get_para_tif_inf(name)

            if data_use_type == 'train':
                phenology_id = int(data[2])
                phenology_name = data[3]
                if phenology_id not in type_data:
                    type_data[phenology_id] = phenology_name
            else:
                # non-training rows carry no class label
                phenology_id = -1

            for points in self.__roiPolygonAnalysis(data[4]):
                poly = []
                for point in points:
                    lon = float(point[0])
                    lat = float(point[1])
                    projs = lonlat2geo(dataset, lon, lat)
                    coord = geo2imagexy(dataset, projs[1], projs[0])
                    row = round(coord[1])
                    col = round(coord[0])
                    if 0 <= row < rows and 0 <= col < cols:
                        poly.append([row, col])
                    else:
                        logger.warning("measure data: %s is beyond tif scope !", data)
                point_list.extend(self.__render(poly))

            if data_use_type == 'train':
                train_data.append([name, phenology_id, point_list])
            elif data_use_type == 'test':
                test_data.append([name, phenology_id, point_list])

        type_map = [
            [index, type_id, type_name]
            for index, (type_id, type_name) in enumerate(type_data.items(), start=1)
        ]
        return train_data, test_data, type_map

    @staticmethod
    def __render(poly):
        # https://www.cnpython.com/qa/51516
        """Return polygon as grid of points inside polygon.

        Input : poly (list of [row, col] vertices)
        Output : output (list of (row, col) tuples filled by mahotas);
                 an empty list when poly has no vertices
        """
        if not poly:
            # every vertex was outside the image: nothing to rasterise
            return []
        xs, ys = zip(*poly)
        minx, maxx = min(xs), max(xs)
        miny, maxy = min(ys), max(ys)
        # shift the polygon to the origin so the canvas only needs to cover
        # its bounding box
        newPoly = [(int(x - minx), int(y - miny)) for (x, y) in poly]
        X = maxx - minx + 1
        Y = maxy - miny + 1
        grid = np.zeros((X, Y), dtype=np.int8)
        mahotas.polygon.fill_polygon(newPoly, grid)
        # shift the filled cells back to image coordinates
        return [(x + minx, y + miny) for (x, y) in zip(*np.nonzero(grid))]

    def __get_para_tif_inf(self, tif_name):
        """
        Get the information of an image.
        :para tif_name: image name (without polarisation suffix)
        :return: (dataset, rows, cols) obtained from ImageHandler for the
            ``<tif_name>_HH`` entry of the preprocessed parameters
        """
        tif_path = self.__preprocessed_paras[tif_name + '_HH']
        ih = ImageHandler()
        dataset = ih.get_dataset(tif_path)
        rows = ih.get_img_height(tif_path)
        cols = ih.get_img_width(tif_path)
        return dataset, rows, cols

    @staticmethod
    def __roiPolygonAnalysis(roiStr):
        """
        Convert the POLY data of the csv file to arrays.
        :para roiStr: poly data, e.g. ``POLYGON((x1 y1,x2 y2,...))``
        :return pointList: list of polygons, each a list of [x, y] floats
        """
        strContent = roiStr.replace("POLYGON", "")
        # split the outline string into one text chunk per closing bracket
        strList = []
        chars = []
        for c in strContent:
            if c == '(':
                continue
            if c == ')':
                if chars:
                    strList.append(''.join(chars))
                    chars = []
            else:
                chars.append(c)

        pointList = []
        for item in strList:
            pList = []
            for row in item.split(','):
                # split on any whitespace so "x y", " x y" and "x  y" all
                # parse; split(' ') used to drop vertices with extra blanks
                cells = row.split()
                if len(cells) != 2:
                    continue
                pList.append([float(cells[0]), float(cells[1])])
            pointList.append(pList)
        return pointList