add extract_freature.py
parent
e792747357
commit
0fb1483359
|
|
@ -0,0 +1,591 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
DINOv3特征提取脚本 - 原生分辨率处理
|
||||
支持图像上采样并保证1 patch对应16×16 pixels
|
||||
|
||||
任务: task_20251204_P3921_upsampling_features
|
||||
输入: TestImages/P3921.png (2044×1896)
|
||||
上采样: 4x → 8176×7584
|
||||
模型: DINOv3-ViT-7B/16-SAT
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Dict, Optional
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from sklearn.decomposition import PCA
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import cm
|
||||
from tqdm import tqdm
|
||||
|
||||
# Add the project root directory (the parent of this script's folder) to the import path
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
class NativeResolutionFeatureExtractor:
    """Extract dense DINOv3 patch features from an image at native resolution.

    Key properties:
        - Preprocessing is done manually, so no AutoImageProcessor resize is
          ever applied to the input.
        - The image is processed in overlapping tiles to support large inputs.
        - One patch always corresponds to PATCH_SIZE x PATCH_SIZE (16x16)
          pixels.
        - Tile features are copied to the CPU and the CUDA cache is released
          after every tile.
    """

    # Side length, in pixels, of one patch (the "/16" of ViT-7B/16).
    PATCH_SIZE = 16

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        tile_size: int = 512,
        overlap: int = 64,
        gpu_id: int = 0
    ):
        """Validate the tiling parameters, pick the device and load the model.

        Args:
            model_path: Path to the DINOv3 model weights.
            device: Compute device ("cuda" or "cpu"). Falls back to "cpu"
                when CUDA is requested but not available.
            tile_size: Tile side length in pixels (512 recommended).
            overlap: Overlap between neighbouring tiles in pixels
                (64 recommended).
            gpu_id: CUDA device index to use (default: 0).

        Raises:
            ValueError: If ``tile_size`` or ``overlap`` is not a multiple of
                the patch size, or if ``overlap`` is not smaller than
                ``tile_size``.
        """
        patch = self.PATCH_SIZE
        # Tile origins are multiples of (tile_size - overlap); they must land
        # on patch boundaries or tiles would be blended into the wrong cells.
        if tile_size <= 0 or tile_size % patch != 0:
            raise ValueError(
                f"tile_size must be a positive multiple of {patch}, got {tile_size}"
            )
        if overlap < 0 or overlap % patch != 0:
            raise ValueError(
                f"overlap must be a non-negative multiple of {patch}, got {overlap}"
            )
        if overlap >= tile_size:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
            )

        self.device = device
        self.tile_size = tile_size
        self.overlap = overlap
        self.model_path = model_path
        self.gpu_id = gpu_id

        # Select the GPU, or fall back to the CPU when CUDA is unavailable.
        if self.device == "cuda":
            if not torch.cuda.is_available():
                print("WARNING: CUDA not available, falling back to CPU")
                self.device = "cpu"
            else:
                torch.cuda.set_device(self.gpu_id)
                print(f"Using GPU: {torch.cuda.get_device_name(self.gpu_id)} (ID: {self.gpu_id})")

        # ImageNet normalisation statistics, stored as float32 so that
        # normalising a float32 image does not promote it to float64.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

        print(f"Loading DINOv3 model from: {model_path}")
        self.model = self._load_model()
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _load_model(self):
        """Load the DINOv3 model from ``self.model_path`` onto the device."""
        from transformers import AutoModel

        # AutoModel resolves the concrete model class from the checkpoint.
        # NOTE(review): trust_remote_code=True executes Python code shipped
        # with the checkpoint directory; only point this at trusted weights.
        model = AutoModel.from_pretrained(
            self.model_path,
            local_files_only=True,
            trust_remote_code=True
        )

        if self.device == "cuda":
            model = model.cuda(self.gpu_id)
        else:
            model = model.to(self.device)

        return model

    def preprocess_manual(self, image_pil: "Image.Image") -> torch.Tensor:
        """Convert a PIL image to a normalised tensor without resizing it.

        Bypasses AutoImageProcessor so the native resolution is preserved.

        Args:
            image_pil: Input image; converted to RGB when it is in any other
                mode.

        Returns:
            A float32 tensor of shape [1, 3, H, W] on the configured device.
        """
        if image_pil.mode != 'RGB':
            image_pil = image_pil.convert('RGB')

        # Scale to [0, 1], then normalise in place. Mean/std are cast to
        # float32 so the (potentially very large) array stays float32.
        img_array = np.array(image_pil).astype(np.float32)
        img_array /= 255.0
        img_array -= np.asarray(self.mean, dtype=np.float32)
        img_array /= np.asarray(self.std, dtype=np.float32)

        # [H, W, C] -> [C, H, W] -> [1, C, H, W]
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)

        if self.device == "cuda":
            return img_tensor.cuda(self.gpu_id)
        return img_tensor.to(self.device)

    def _tile_origins(self, extent: int) -> list:
        """Return the pixel offsets at which tiles start along one axis.

        Offsets advance by ``tile_size - overlap``. At least one tile is
        always produced, including for extents not larger than the overlap.
        """
        stride = self.tile_size - self.overlap
        if stride <= 0:
            raise ValueError(
                f"overlap ({self.overlap}) must be smaller than tile_size ({self.tile_size})"
            )
        # Ceiling division without floats.
        count = max(1, -(-(extent - self.overlap) // stride))
        return [k * stride for k in range(count)]

    def extract_features_tiled(
        self,
        image_tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, int, int]:
        """Extract patch features tile by tile and blend the overlaps.

        The image is cut into ``tile_size`` x ``tile_size`` tiles. Border
        tiles are zero-padded on the right/bottom to full size and the padded
        patches are discarded. Each patch of the result is the weighted
        average of the features from every tile covering it.

        Args:
            image_tensor: Normalised image of shape [1, C, H, W].

        Returns:
            features: [1, num_patches, feature_dim] float32 tensor on the CPU.
            num_patches_h: Number of patches along the height.
            num_patches_w: Number of patches along the width.

        Raises:
            ValueError: If the image is smaller than one patch.
            RuntimeError: If the model returns fewer tokens than there are
                patches in a tile.
        """
        patch = self.PATCH_SIZE
        _, _, H, W = image_tensor.shape

        # Size of the full patch grid (one patch per 16x16 pixels).
        num_patches_h = H // patch
        num_patches_w = W // patch
        if num_patches_h == 0 or num_patches_w == 0:
            raise ValueError(
                f"Image of size {W}x{H} is smaller than one {patch}x{patch} patch"
            )

        print(f"Image size: {W}×{H}")
        print(f"Patch grid: {num_patches_w}×{num_patches_h} = {num_patches_w * num_patches_h} patches")

        row_starts = self._tile_origins(H)
        col_starts = self._tile_origins(W)
        n_tiles_h, n_tiles_w = len(row_starts), len(col_starts)

        print(f"Processing with {n_tiles_w}×{n_tiles_h} = {n_tiles_w * n_tiles_h} tiles")
        print(f"Tile size: {self.tile_size}×{self.tile_size}, overlap: {self.overlap}")

        tile_patches = self.tile_size // patch
        patch_tokens_per_tile = tile_patches * tile_patches

        # Blending weights for one full tile (high in the centre, low at the
        # edges, never zero).
        tile_weight = self._create_tile_weight(tile_patches, self.overlap // patch)

        # The feature accumulator is allocated once the first tile reveals
        # the model's feature dimension.
        feature_map = None
        feature_dim = 0
        weight_map = torch.zeros(
            (1, num_patches_h, num_patches_w, 1),
            dtype=torch.float32
        )

        with torch.no_grad():
            for y_start in tqdm(row_starts, desc="Processing tiles"):
                for x_start in col_starts:
                    y_end = min(y_start + self.tile_size, H)
                    x_end = min(x_start + self.tile_size, W)
                    tile_h = y_end - y_start
                    tile_w = x_end - x_start

                    # Whole patches actually covered by image content.
                    actual_patch_h = tile_h // patch
                    actual_patch_w = tile_w // patch
                    if actual_patch_h == 0 or actual_patch_w == 0:
                        # Sliver narrower than one patch: nothing to add.
                        continue

                    tile = image_tensor[:, :, y_start:y_end, x_start:x_end]

                    # Pad border tiles to the full tile size with zeros.
                    pad_h = self.tile_size - tile_h
                    pad_w = self.tile_size - tile_w
                    if pad_h or pad_w:
                        # Padding order: (left, right, top, bottom)
                        tile = F.pad(tile, (0, pad_w, 0, pad_h), mode='constant', value=0)

                    outputs = self.model(tile, output_hidden_states=False)
                    tokens = outputs.last_hidden_state

                    # The model output is [CLS, register tokens, patch tokens];
                    # the patch tokens are the trailing tile_patches**2 entries,
                    # so the prefix length never has to be hard-coded.
                    if tokens.shape[1] < patch_tokens_per_tile:
                        raise RuntimeError(
                            f"Model returned {tokens.shape[1]} tokens, expected at least "
                            f"{patch_tokens_per_tile} patch tokens per tile"
                        )
                    tile_features = tokens[:, -patch_tokens_per_tile:, :]
                    feature_dim = tile_features.shape[-1]

                    # Tokens -> 2D grid, drop padded patches, move to the CPU.
                    tile_features = tile_features.reshape(
                        1, tile_patches, tile_patches, feature_dim
                    )
                    tile_features = tile_features[:, :actual_patch_h, :actual_patch_w, :]
                    tile_features = tile_features.to(device="cpu", dtype=torch.float32)

                    if feature_map is None:
                        feature_map = torch.zeros(
                            (1, num_patches_h, num_patches_w, feature_dim),
                            dtype=torch.float32
                        )

                    # Position of this tile inside the global patch grid.
                    patch_y_start = y_start // patch
                    patch_x_start = x_start // patch
                    patch_y_end = patch_y_start + actual_patch_h
                    patch_x_end = patch_x_start + actual_patch_w

                    tile_w_adj = tile_weight[:actual_patch_h, :actual_patch_w].unsqueeze(0).unsqueeze(-1)

                    feature_map[:, patch_y_start:patch_y_end, patch_x_start:patch_x_end, :] += \
                        tile_features * tile_w_adj
                    weight_map[:, patch_y_start:patch_y_end, patch_x_start:patch_x_end, :] += \
                        tile_w_adj

                    if self.device == "cuda":
                        torch.cuda.empty_cache()

        if feature_map is None:
            raise RuntimeError("No tile produced any features")

        # Weighted average; done in place to avoid a second full-size copy.
        feature_map /= weight_map.clamp_min(1e-8)

        features = feature_map.reshape(1, num_patches_h * num_patches_w, feature_dim)

        return features, num_patches_h, num_patches_w

    def _create_tile_weight(self, tile_patch_size: int, overlap_patches: int) -> torch.Tensor:
        """Build the per-patch blending weights for one tile.

        The centre has weight 1.0 and the outer ``overlap_patches`` rows and
        columns ramp down linearly towards the tile edge. The ramp never
        reaches zero, because patches on the image border are covered by a
        single tile and would otherwise get a total weight of zero. Ramps of
        two adjacent tiles still sum to 1 across their overlap.

        Args:
            tile_patch_size: Tile side length in patches.
            overlap_patches: Width of the overlap band in patches.

        Returns:
            A [tile_patch_size, tile_patch_size] float tensor of weights in
            (0, 1].
        """
        weight = torch.ones((tile_patch_size, tile_patch_size))

        # A band wider than the tile cannot be broadcast onto it.
        overlap_patches = min(overlap_patches, tile_patch_size)

        if overlap_patches > 0:
            # Interior points of a 0..1 ramp: strictly between 0 and 1.
            fade = torch.linspace(0, 1, overlap_patches + 2)[1:-1]

            # Top edge
            weight[:overlap_patches, :] *= fade.unsqueeze(1)
            # Bottom edge
            weight[-overlap_patches:, :] *= fade.flip(0).unsqueeze(1)
            # Left edge
            weight[:, :overlap_patches] *= fade.unsqueeze(0)
            # Right edge
            weight[:, -overlap_patches:] *= fade.flip(0).unsqueeze(0)

        return weight

    def upsample_image(self, image_pil: "Image.Image", scale: int) -> "Image.Image":
        """Upsample an image by an integer factor with bicubic interpolation.

        Args:
            image_pil: Image to enlarge.
            scale: Integer magnification factor; 1 returns the input as is.

        Returns:
            The enlarged image (or ``image_pil`` itself when ``scale`` is 1).

        Raises:
            ValueError: If ``scale`` is smaller than 1.
        """
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if scale == 1:
            return image_pil

        w, h = image_pil.size
        new_w, new_h = w * scale, h * scale

        print(f"Upsampling image: {w}×{h} -> {new_w}×{new_h} ({scale}x)")

        upsampled = image_pil.resize(
            (new_w, new_h),
            Image.BICUBIC
        )

        return upsampled

    def generate_pca_visualization(
        self,
        features: torch.Tensor,
        num_patches_h: int,
        num_patches_w: int,
        output_size: Tuple[int, int]
    ) -> np.ndarray:
        """Render the patch features as an RGB image via a 3-component PCA.

        Args:
            features: [1, num_patches, dim] feature tensor.
            num_patches_h: Number of patches along the height.
            num_patches_w: Number of patches along the width.
            output_size: Accepted for interface compatibility; not used. The
                output size is always the patch grid times the patch size.

        Returns:
            uint8 array of shape [num_patches_h * 16, num_patches_w * 16, 3].
        """
        print("Generating PCA visualization...")

        # [1, num_patches, dim] -> [num_patches, dim]
        features_np = features.squeeze(0).cpu().numpy()

        # Project onto the three leading principal components.
        pca = PCA(n_components=3)
        features_pca = pca.fit_transform(features_np)

        print(f"PCA explained variance ratio: {pca.explained_variance_ratio_}")

        # [num_patches, 3] -> [H, W, 3]
        pca_image = features_pca.reshape(num_patches_h, num_patches_w, 3)

        # Min-max normalise to [0, 1]; a constant input maps to all zeros
        # instead of dividing by zero.
        lowest = pca_image.min()
        value_range = pca_image.max() - lowest
        if value_range > 0:
            pca_image = (pca_image - lowest) / value_range
        else:
            pca_image = np.zeros_like(pca_image)

        # Pixel size that the patch grid corresponds to.
        actual_h = num_patches_h * self.PATCH_SIZE
        actual_w = num_patches_w * self.PATCH_SIZE

        pca_tensor = torch.from_numpy(pca_image).permute(2, 0, 1).unsqueeze(0).float()
        pca_upsampled = F.interpolate(
            pca_tensor,
            size=(actual_h, actual_w),
            mode='bilinear',
            align_corners=False
        )
        pca_upsampled = pca_upsampled.squeeze(0).permute(1, 2, 0).numpy()

        pca_uint8 = (pca_upsampled * 255).astype(np.uint8)

        return pca_uint8

    def generate_heatmap(
        self,
        features: torch.Tensor,
        num_patches_h: int,
        num_patches_w: int,
        output_size: Tuple[int, int],
        center_point: Optional[Tuple[int, int]] = None
    ) -> np.ndarray:
        """Render the cosine similarity of every patch to a reference patch.

        Args:
            features: [1, num_patches, dim] feature tensor.
            num_patches_h: Number of patches along the height.
            num_patches_w: Number of patches along the width.
            output_size: Accepted for interface compatibility; not used. The
                output size is always the patch grid times the patch size.
            center_point: Reference patch as (row, column) in the patch grid;
                defaults to the centre of the grid.

        Returns:
            uint8 RGB array of shape
            [num_patches_h * 16, num_patches_w * 16, 3] coloured with the
            "inferno" colormap.
        """
        print("Generating center point heatmap...")

        if center_point is None:
            center_point = (num_patches_h // 2, num_patches_w // 2)

        print(f"Using center point: patch ({center_point[1]}, {center_point[0]})")

        # [1, num_patches, dim] -> [H, W, dim]
        features_map = features.squeeze(0).reshape(num_patches_h, num_patches_w, -1)

        center_feature = features_map[center_point[0], center_point[1]]

        # Cosine similarity = dot product of L2-normalised vectors.
        features_flat = features_map.reshape(-1, features_map.shape[-1])
        center_feature_norm = center_feature / (torch.norm(center_feature) + 1e-8)
        features_norm = features_flat / (torch.norm(features_flat, dim=1, keepdim=True) + 1e-8)

        similarity = torch.matmul(features_norm, center_feature_norm.unsqueeze(-1)).squeeze(-1)
        similarity_map = similarity.reshape(num_patches_h, num_patches_w)

        similarity_np = similarity_map.cpu().numpy()

        # Min-max normalise to [0, 1]; a constant map becomes all zeros
        # instead of dividing by zero.
        lowest = similarity_np.min()
        value_range = similarity_np.max() - lowest
        if value_range > 0:
            similarity_np = (similarity_np - lowest) / value_range
        else:
            similarity_np = np.zeros_like(similarity_np)

        # Pixel size that the patch grid corresponds to.
        actual_h = num_patches_h * self.PATCH_SIZE
        actual_w = num_patches_w * self.PATCH_SIZE

        similarity_tensor = torch.from_numpy(similarity_np).unsqueeze(0).unsqueeze(0).float()
        similarity_upsampled = F.interpolate(
            similarity_tensor,
            size=(actual_h, actual_w),
            mode='bilinear',
            align_corners=False
        )
        similarity_upsampled = similarity_upsampled.squeeze(0).squeeze(0).numpy()

        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        cmap = plt.get_cmap('inferno')
        heatmap = cmap(similarity_upsampled)[:, :, :3]  # drop the alpha channel
        heatmap_uint8 = (heatmap * 255).astype(np.uint8)

        return heatmap_uint8

    def process(
        self,
        input_path: str,
        output_dir: str,
        upscale: int = 4
    ):
        """Run the full extraction and visualisation pipeline for one image.

        Writes the upsampled image, the PCA visualisation, the similarity
        heatmap and a ``summary.json`` metadata file into ``output_dir``
        (created if missing).

        Args:
            input_path: Path of the image to process.
            output_dir: Directory that receives all output files.
            upscale: Integer upsampling factor applied before extraction.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        print("="*60)
        print(f"DINOv3 Native Resolution Feature Extraction")
        print("="*60)
        print(f"Input: {input_path}")
        print(f"Output: {output_dir}")
        print(f"Upscale: {upscale}x")
        print(f"Model: {self.model_path}")
        print(f"Device: {self.device}")
        if self.device == "cuda":
            print(f"GPU ID: {self.gpu_id}")
        print("="*60)

        # 1. Load and upsample the image.
        print("\n[1/5] Loading and upsampling image...")
        # The context manager closes the file handle. Converting to RGB up
        # front yields an in-memory copy in the mode the model consumes, so
        # the saved upsampled image is exactly what the features come from.
        with Image.open(input_path) as source_image:
            image_pil = source_image.convert('RGB')
        print(f"Original size: {image_pil.size[0]}×{image_pil.size[1]}")

        upsampled_image = self.upsample_image(image_pil, upscale)
        W, H = upsampled_image.size
        print(f"Processing size: {W}×{H}")

        upsampled_path = output_path / f"upsampled_{upscale}x.png"
        upsampled_image.save(upsampled_path)
        print(f"Saved upsampled image to: {upsampled_path}")

        # 2. Preprocess.
        print("\n[2/5] Preprocessing image...")
        image_tensor = self.preprocess_manual(upsampled_image)
        print(f"Input tensor shape: {image_tensor.shape}")

        # 3. Extract features.
        print("\n[3/5] Extracting features...")
        start_time = time.time()
        features, num_patches_h, num_patches_w = self.extract_features_tiled(image_tensor)
        extraction_time = time.time() - start_time
        print(f"Feature extraction completed in {extraction_time:.2f}s")
        print(f"Feature shape: {features.shape}")

        # Sanity check of the patch grid against the processed image size.
        expected_patches_h = H // self.PATCH_SIZE
        expected_patches_w = W // self.PATCH_SIZE
        assert num_patches_h == expected_patches_h, \
            f"Patch height mismatch: {num_patches_h} != {expected_patches_h}"
        assert num_patches_w == expected_patches_w, \
            f"Patch width mismatch: {num_patches_w} != {expected_patches_w}"
        print(f"✓ Patch alignment verified: 1 patch = 16×16 pixels")

        # 4. PCA visualisation.
        print("\n[4/5] Generating PCA visualization...")
        pca_image = self.generate_pca_visualization(
            features, num_patches_h, num_patches_w, (H, W)
        )
        pca_path = output_path / "pca_rainbow.png"
        Image.fromarray(pca_image).save(pca_path)
        print(f"Saved PCA visualization to: {pca_path}")
        print(f"PCA image size: {pca_image.shape[1]}×{pca_image.shape[0]}")

        # 5. Similarity heatmap.
        print("\n[5/5] Generating heatmap...")
        heatmap_image = self.generate_heatmap(
            features, num_patches_h, num_patches_w, (H, W)
        )
        heatmap_path = output_path / "pca_heatmap.png"
        Image.fromarray(heatmap_image).save(heatmap_path)
        print(f"Saved heatmap to: {heatmap_path}")
        print(f"Heatmap size: {heatmap_image.shape[1]}×{heatmap_image.shape[0]}")

        # 6. Metadata.
        print("\nSaving metadata...")
        metadata = {
            "input_image": str(input_path),
            "original_size": {"width": image_pil.size[0], "height": image_pil.size[1]},
            "upscale_factor": upscale,
            "processing_size": {"width": W, "height": H},
            "model": self.model_path,
            "device": self.device,
            "gpu_id": self.gpu_id if self.device == "cuda" else None,
            "patch_grid": {"height": num_patches_h, "width": num_patches_w},
            "num_patches": num_patches_h * num_patches_w,
            "feature_dim": features.shape[-1],
            "tile_size": self.tile_size,
            "overlap": self.overlap,
            "extraction_time_seconds": extraction_time,
            "patch_size_pixels": self.PATCH_SIZE,
            "verification": "1 patch = 16×16 pixels ✓"
        }

        metadata_path = output_path / "summary.json"
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)
        print(f"Saved metadata to: {metadata_path}")

        print("\n" + "="*60)
        print("Processing completed successfully!")
        print("="*60)
        print(f"\nOutput files:")
        print(f" - Upsampled image: {upsampled_path}")
        print(f" - PCA visualization: {pca_path}")
        print(f" - Heatmap: {heatmap_path}")
        print(f" - Metadata: {metadata_path}")
|
||||
|
||||
|
||||
def main(argv=None):
    """Parse the command line, validate it and run the extraction pipeline.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Invalid tiling or upscale values end the program through
    ``parser.error`` (exit status 2) before the model is loaded.
    """
    parser = argparse.ArgumentParser(
        description="DINOv3 Native Resolution Feature Extraction"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input image path"
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="/data/yaoju/model_dinoV3/dinov3-vit7b16-pretrain-sat493m",
        help="DINOv3 model path"
    )
    parser.add_argument(
        "--upscale",
        type=int,
        default=4,
        help="Upscale factor (default: 4)"
    )
    parser.add_argument(
        "--tile-size",
        type=int,
        default=512,
        help="Tile size for processing (default: 512)"
    )
    parser.add_argument(
        "--overlap",
        type=int,
        default=64,
        help="Tile overlap in pixels (default: 64)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to use (default: cuda)"
    )
    parser.add_argument(
        "--gpu-id",
        type=int,
        default=0,
        help="GPU device ID to use (default: 0)"
    )

    args = parser.parse_args(argv)

    # Reject values that would give a non-positive tile stride, an invalid
    # resize, or tiles that do not line up with the 16-pixel patch grid.
    patch_size = 16
    if args.upscale < 1:
        parser.error("--upscale must be >= 1")
    if args.tile_size <= 0 or args.tile_size % patch_size != 0:
        parser.error(f"--tile-size must be a positive multiple of {patch_size}")
    if args.overlap < 0 or args.overlap % patch_size != 0:
        parser.error(f"--overlap must be a non-negative multiple of {patch_size}")
    if args.overlap >= args.tile_size:
        parser.error("--overlap must be smaller than --tile-size")

    # Fall back to the CPU when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("WARNING: CUDA not available, falling back to CPU")
        args.device = "cpu"

    # Replace an out-of-range GPU index with device 0.
    if args.device == "cuda":
        if args.gpu_id < 0 or args.gpu_id >= torch.cuda.device_count():
            print(f"WARNING: GPU ID {args.gpu_id} is invalid. Available GPUs: {torch.cuda.device_count()}")
            print(f"Using default GPU ID: 0")
            args.gpu_id = 0

    # Build the extractor and run the pipeline.
    extractor = NativeResolutionFeatureExtractor(
        model_path=args.model,
        device=args.device,
        tile_size=args.tile_size,
        overlap=args.overlap,
        gpu_id=args.gpu_id
    )

    extractor.process(
        input_path=args.input,
        output_dir=args.output,
        upscale=args.upscale
    )


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,591 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
DINOv3特征提取脚本 - 原生分辨率处理
|
||||
支持图像上采样并保证1 patch对应16×16 pixels
|
||||
|
||||
任务: task_20251204_P3921_upsampling_features
|
||||
输入: TestImages/P3921.png (2044×1896)
|
||||
上采样: 4x → 8176×7584
|
||||
模型: DINOv3-ViT-7B/16-SAT
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Dict, Optional
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from sklearn.decomposition import PCA
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import cm
|
||||
from tqdm import tqdm
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
class NativeResolutionFeatureExtractor:
|
||||
"""原生分辨率特征提取器
|
||||
|
||||
关键特性:
|
||||
- 完全手动预处理,绕过AutoImageProcessor的自动resize
|
||||
- 分块处理,支持大尺寸图像
|
||||
- 保证1 patch = 16×16 pixels
|
||||
- 显存优化,逐块处理并清理GPU缓存
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
tile_size: int = 512,
|
||||
overlap: int = 64,
|
||||
gpu_id: int = 0
|
||||
):
|
||||
"""初始化特征提取器
|
||||
|
||||
Args:
|
||||
model_path: DINOv3模型权重路径
|
||||
device: 计算设备 ("cuda" 或 "cpu")
|
||||
tile_size: 分块大小 (推荐512)
|
||||
overlap: 块之间的重叠像素数 (推荐64)
|
||||
gpu_id: 指定GPU设备ID (默认: 0)
|
||||
"""
|
||||
self.device = device
|
||||
self.tile_size = tile_size
|
||||
self.overlap = overlap
|
||||
self.model_path = model_path
|
||||
self.gpu_id = gpu_id
|
||||
|
||||
# 设置GPU设备
|
||||
if self.device == "cuda":
|
||||
# 检查CUDA可用性
|
||||
if not torch.cuda.is_available():
|
||||
print("WARNING: CUDA not available, falling back to CPU")
|
||||
self.device = "cpu"
|
||||
else:
|
||||
# 设置默认GPU设备
|
||||
torch.cuda.set_device(self.gpu_id)
|
||||
print(f"Using GPU: {torch.cuda.get_device_name(self.gpu_id)} (ID: {self.gpu_id})")
|
||||
|
||||
# ImageNet标准化参数
|
||||
self.mean = np.array([0.485, 0.456, 0.406])
|
||||
self.std = np.array([0.229, 0.224, 0.225])
|
||||
|
||||
# 加载模型
|
||||
print(f"Loading DINOv3 model from: {model_path}")
|
||||
self.model = self._load_model()
|
||||
self.model.eval()
|
||||
print(f"Model loaded successfully on {self.device}")
|
||||
|
||||
def _load_model(self):
|
||||
"""加载DINOv3模型"""
|
||||
from transformers import AutoModel
|
||||
|
||||
# 使用AutoModel自动识别正确的模型类
|
||||
model = AutoModel.from_pretrained(
|
||||
self.model_path,
|
||||
local_files_only=True,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# 将模型移动到指定设备
|
||||
if self.device == "cuda":
|
||||
model = model.cuda(self.gpu_id)
|
||||
else:
|
||||
model = model.to(self.device)
|
||||
|
||||
return model
|
||||
|
||||
def preprocess_manual(self, image_pil: Image.Image) -> torch.Tensor:
|
||||
"""手动预处理图像,完全控制尺寸
|
||||
|
||||
绕过AutoImageProcessor,保证原生分辨率
|
||||
"""
|
||||
# 转为RGB
|
||||
if image_pil.mode != 'RGB':
|
||||
image_pil = image_pil.convert('RGB')
|
||||
|
||||
# 转为numpy数组并归一化到[0, 1]
|
||||
img_array = np.array(image_pil).astype(np.float32) / 255.0
|
||||
|
||||
# ImageNet标准化
|
||||
img_array = (img_array - self.mean) / self.std
|
||||
|
||||
# 转为torch tensor: [H, W, C] -> [C, H, W] -> [1, C, H, W]
|
||||
img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)
|
||||
|
||||
# 移动到指定设备
|
||||
if self.device == "cuda":
|
||||
return img_tensor.cuda(self.gpu_id)
|
||||
else:
|
||||
return img_tensor.to(self.device)
|
||||
|
||||
def extract_features_tiled(
|
||||
self,
|
||||
image_tensor: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, int, int]:
|
||||
"""分块提取特征
|
||||
|
||||
将大图分成tile_size×tile_size的块,逐块处理并融合
|
||||
|
||||
Returns:
|
||||
features: [1, num_patches, feature_dim] 特征张量
|
||||
num_patches_h: 高度方向patch数量
|
||||
num_patches_w: 宽度方向patch数量
|
||||
"""
|
||||
_, _, H, W = image_tensor.shape
|
||||
|
||||
# 计算总patch数量 (16×16一个patch)
|
||||
num_patches_h = H // 16
|
||||
num_patches_w = W // 16
|
||||
|
||||
print(f"Image size: {W}×{H}")
|
||||
print(f"Patch grid: {num_patches_w}×{num_patches_h} = {num_patches_w * num_patches_h} patches")
|
||||
|
||||
# 计算tile参数
|
||||
stride = self.tile_size - self.overlap
|
||||
n_tiles_h = (H - self.overlap + stride - 1) // stride
|
||||
n_tiles_w = (W - self.overlap + stride - 1) // stride
|
||||
|
||||
print(f"Processing with {n_tiles_w}×{n_tiles_h} = {n_tiles_w * n_tiles_h} tiles")
|
||||
print(f"Tile size: {self.tile_size}×{self.tile_size}, overlap: {self.overlap}")
|
||||
|
||||
# 初始化特征累积图和权重图
|
||||
feature_dim = 4096 # ViT-7B/16的特征维度
|
||||
feature_map = torch.zeros(
|
||||
(1, num_patches_h, num_patches_w, feature_dim),
|
||||
dtype=torch.float32
|
||||
)
|
||||
weight_map = torch.zeros(
|
||||
(1, num_patches_h, num_patches_w, 1),
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
# 生成tile的权重图 (中心权重高,边缘权重低)
|
||||
tile_weight = self._create_tile_weight(self.tile_size // 16, self.overlap // 16)
|
||||
|
||||
# 逐块处理
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(n_tiles_h), desc="Processing tiles"):
|
||||
for j in range(n_tiles_w):
|
||||
# 计算tile边界
|
||||
y_start = i * stride
|
||||
x_start = j * stride
|
||||
y_end = min(y_start + self.tile_size, H)
|
||||
x_end = min(x_start + self.tile_size, W)
|
||||
|
||||
# 提取tile
|
||||
tile = image_tensor[:, :, y_start:y_end, x_start:x_end]
|
||||
|
||||
# 如果tile不是标准大小,需要padding (使用constant 0填充)
|
||||
tile_h, tile_w = tile.shape[2], tile.shape[3]
|
||||
if tile_h != self.tile_size or tile_w != self.tile_size:
|
||||
pad_h = self.tile_size - tile_h
|
||||
pad_w = self.tile_size - tile_w
|
||||
# padding顺序: (left, right, top, bottom)
|
||||
tile = F.pad(tile, (0, pad_w, 0, pad_h), mode='constant', value=0)
|
||||
|
||||
# 通过模型提取特征
|
||||
outputs = self.model(tile, output_hidden_states=False)
|
||||
|
||||
# DINOv3模型返回: [CLS token, register tokens (4个), patch tokens]
|
||||
# 需要去掉CLS token和register tokens
|
||||
num_register_tokens = 4
|
||||
tile_features = outputs.last_hidden_state[:, 1 + num_register_tokens:, :] # 去掉CLS和register tokens
|
||||
|
||||
# Reshape为2D特征图
|
||||
tile_patch_h = self.tile_size // 16
|
||||
tile_patch_w = self.tile_size // 16
|
||||
tile_features = tile_features.reshape(1, tile_patch_h, tile_patch_w, feature_dim)
|
||||
|
||||
# 如果有padding,裁剪掉padding部分
|
||||
actual_patch_h = tile_h // 16
|
||||
actual_patch_w = tile_w // 16
|
||||
tile_features = tile_features[:, :actual_patch_h, :actual_patch_w, :]
|
||||
|
||||
# 计算特征图中的位置
|
||||
patch_y_start = y_start // 16
|
||||
patch_x_start = x_start // 16
|
||||
patch_y_end = patch_y_start + actual_patch_h
|
||||
patch_x_end = patch_x_start + actual_patch_w
|
||||
|
||||
# 调整tile权重大小
|
||||
tile_w_adj = tile_weight[:actual_patch_h, :actual_patch_w].unsqueeze(0).unsqueeze(-1)
|
||||
|
||||
# 累积特征和权重
|
||||
feature_map[:, patch_y_start:patch_y_end, patch_x_start:patch_x_end, :] += \
|
||||
tile_features.cpu() * tile_w_adj
|
||||
weight_map[:, patch_y_start:patch_y_end, patch_x_start:patch_x_end, :] += \
|
||||
tile_w_adj
|
||||
|
||||
# 清理GPU缓存
|
||||
if self.device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 归一化特征 (加权平均)
|
||||
feature_map = feature_map / (weight_map + 1e-8)
|
||||
|
||||
# Reshape为 [1, num_patches, feature_dim]
|
||||
features = feature_map.reshape(1, num_patches_h * num_patches_w, feature_dim)
|
||||
|
||||
return features, num_patches_h, num_patches_w
|
||||
|
||||
def _create_tile_weight(self, tile_patch_size: int, overlap_patches: int) -> torch.Tensor:
|
||||
"""创建tile权重图
|
||||
|
||||
中心区域权重为1.0,边缘区域线性衰减
|
||||
"""
|
||||
weight = torch.ones((tile_patch_size, tile_patch_size))
|
||||
|
||||
if overlap_patches > 0:
|
||||
# 创建线性衰减
|
||||
fade = torch.linspace(0, 1, overlap_patches)
|
||||
|
||||
# 上边缘
|
||||
weight[:overlap_patches, :] *= fade.unsqueeze(1)
|
||||
# 下边缘
|
||||
weight[-overlap_patches:, :] *= fade.flip(0).unsqueeze(1)
|
||||
# 左边缘
|
||||
weight[:, :overlap_patches] *= fade.unsqueeze(0)
|
||||
# 右边缘
|
||||
weight[:, -overlap_patches:] *= fade.flip(0).unsqueeze(0)
|
||||
|
||||
return weight
|
||||
|
||||
def upsample_image(self, image_pil: Image.Image, scale: int) -> Image.Image:
|
||||
"""上采样图像
|
||||
|
||||
使用高质量的双三次插值
|
||||
"""
|
||||
if scale == 1:
|
||||
return image_pil
|
||||
|
||||
w, h = image_pil.size
|
||||
new_w, new_h = w * scale, h * scale
|
||||
|
||||
print(f"Upsampling image: {w}×{h} -> {new_w}×{new_h} ({scale}x)")
|
||||
|
||||
upsampled = image_pil.resize(
|
||||
(new_w, new_h),
|
||||
Image.BICUBIC
|
||||
)
|
||||
|
||||
return upsampled
|
||||
|
||||
def generate_pca_visualization(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
num_patches_h: int,
|
||||
num_patches_w: int,
|
||||
output_size: Tuple[int, int]
|
||||
) -> np.ndarray:
|
||||
"""生成PCA彩色可视化
|
||||
|
||||
将高维特征降维到RGB三通道
|
||||
"""
|
||||
print("Generating PCA visualization...")
|
||||
|
||||
# Reshape特征: [1, num_patches, dim] -> [num_patches, dim]
|
||||
features_np = features.squeeze(0).cpu().numpy()
|
||||
|
||||
# PCA降维到3个主成分
|
||||
pca = PCA(n_components=3)
|
||||
features_pca = pca.fit_transform(features_np)
|
||||
|
||||
print(f"PCA explained variance ratio: {pca.explained_variance_ratio_}")
|
||||
|
||||
# Reshape为图像: [num_patches, 3] -> [H, W, 3]
|
||||
pca_image = features_pca.reshape(num_patches_h, num_patches_w, 3)
|
||||
|
||||
# 归一化到[0, 1]
|
||||
pca_image = (pca_image - pca_image.min()) / (pca_image.max() - pca_image.min())
|
||||
|
||||
# 计算实际对应的像素尺寸 (patch数 × 16)
|
||||
actual_h = num_patches_h * 16
|
||||
actual_w = num_patches_w * 16
|
||||
|
||||
# 上采样到实际patch对应的分辨率
|
||||
pca_tensor = torch.from_numpy(pca_image).permute(2, 0, 1).unsqueeze(0).float()
|
||||
pca_upsampled = F.interpolate(
|
||||
pca_tensor,
|
||||
size=(actual_h, actual_w),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
pca_upsampled = pca_upsampled.squeeze(0).permute(1, 2, 0).numpy()
|
||||
|
||||
# 转为uint8
|
||||
pca_uint8 = (pca_upsampled * 255).astype(np.uint8)
|
||||
|
||||
return pca_uint8
|
||||
|
||||
def generate_heatmap(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
num_patches_h: int,
|
||||
num_patches_w: int,
|
||||
output_size: Tuple[int, int],
|
||||
center_point: Optional[Tuple[int, int]] = None
|
||||
) -> np.ndarray:
|
||||
"""生成中心点热力图
|
||||
|
||||
显示各区域与中心点的特征相似度
|
||||
"""
|
||||
print("Generating center point heatmap...")
|
||||
|
||||
# 默认使用图像中心点
|
||||
if center_point is None:
|
||||
center_point = (num_patches_h // 2, num_patches_w // 2)
|
||||
|
||||
print(f"Using center point: patch ({center_point[1]}, {center_point[0]})")
|
||||
|
||||
# Reshape特征: [1, num_patches, dim] -> [H, W, dim]
|
||||
features_map = features.squeeze(0).reshape(num_patches_h, num_patches_w, -1)
|
||||
|
||||
# 获取中心点特征
|
||||
center_feature = features_map[center_point[0], center_point[1]]
|
||||
|
||||
# 计算余弦相似度
|
||||
features_flat = features_map.reshape(-1, features_map.shape[-1])
|
||||
center_feature_norm = center_feature / (torch.norm(center_feature) + 1e-8)
|
||||
features_norm = features_flat / (torch.norm(features_flat, dim=1, keepdim=True) + 1e-8)
|
||||
|
||||
similarity = torch.matmul(features_norm, center_feature_norm.unsqueeze(-1)).squeeze(-1)
|
||||
similarity_map = similarity.reshape(num_patches_h, num_patches_w)
|
||||
|
||||
# 转为numpy
|
||||
similarity_np = similarity_map.cpu().numpy()
|
||||
|
||||
# 归一化到[0, 1]
|
||||
similarity_np = (similarity_np - similarity_np.min()) / (similarity_np.max() - similarity_np.min())
|
||||
|
||||
# 计算实际对应的像素尺寸 (patch数 × 16)
|
||||
actual_h = num_patches_h * 16
|
||||
actual_w = num_patches_w * 16
|
||||
|
||||
# 上采样到实际patch对应的分辨率
|
||||
similarity_tensor = torch.from_numpy(similarity_np).unsqueeze(0).unsqueeze(0).float()
|
||||
similarity_upsampled = F.interpolate(
|
||||
similarity_tensor,
|
||||
size=(actual_h, actual_w),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
similarity_upsampled = similarity_upsampled.squeeze().numpy()
|
||||
|
||||
# 应用colormap
|
||||
cmap = cm.get_cmap('inferno')
|
||||
heatmap = cmap(similarity_upsampled)[:, :, :3] # 去掉alpha通道
|
||||
heatmap_uint8 = (heatmap * 255).astype(np.uint8)
|
||||
|
||||
return heatmap_uint8
|
||||
|
||||
def process(
    self,
    input_path: str,
    output_dir: str,
    upscale: int = 4
):
    """Run the complete feature-extraction and visualisation pipeline.

    The steps are, in order:
      1. open ``input_path`` and upsample it with ``self.upsample_image``;
         the upsampled image is saved as ``upsampled_{upscale}x.png``
      2. convert the image to a tensor with ``self.preprocess_manual``
      3. extract features with ``self.extract_features_tiled`` and check
         the returned patch grid against the image size
      4. save ``self.generate_pca_visualization`` output as ``pca_rainbow.png``
      5. save ``self.generate_heatmap`` output as ``pca_heatmap.png``
      6. write run metadata to ``summary.json``

    Args:
        input_path: Path of the image to open.
        output_dir: Directory receiving all outputs; created (with parents)
            if it does not exist.
        upscale: Factor forwarded to ``self.upsample_image``.

    Raises:
        AssertionError: If the patch grid returned by the extractor does not
            equal the processed image size divided by the patch size.
    """
    # Side length, in pixels, that one patch is expected to cover.
    patch_size = 16
    banner = "=" * 60

    # Create the output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(banner)
    print("DINOv3 Native Resolution Feature Extraction")
    print(banner)
    print(f"Input: {input_path}")
    print(f"Output: {output_dir}")
    print(f"Upscale: {upscale}x")
    print(f"Model: {self.model_path}")
    print(f"Device: {self.device}")
    if self.device == "cuda":
        print(f"GPU ID: {self.gpu_id}")
    print(banner)

    # 1. Load and upsample the image
    print("\n[1/5] Loading and upsampling image...")
    # The context manager releases the file handle deterministically.
    # load() forces the pixel data into memory first, so the image object
    # stays usable after the file is closed.
    with Image.open(input_path) as image_pil:
        image_pil.load()
        original_w, original_h = image_pil.size
        print(f"Original size: {original_w}×{original_h}")
        upsampled_image = self.upsample_image(image_pil, upscale)

    W, H = upsampled_image.size
    print(f"Processing size: {W}×{H}")

    # Save the upsampled image
    upsampled_path = output_path / f"upsampled_{upscale}x.png"
    upsampled_image.save(upsampled_path)
    print(f"Saved upsampled image to: {upsampled_path}")

    # 2. Preprocess
    print("\n[2/5] Preprocessing image...")
    image_tensor = self.preprocess_manual(upsampled_image)
    print(f"Input tensor shape: {image_tensor.shape}")

    # 3. Extract features
    print("\n[3/5] Extracting features...")
    start_time = time.time()
    features, num_patches_h, num_patches_w = self.extract_features_tiled(image_tensor)
    extraction_time = time.time() - start_time
    print(f"Feature extraction completed in {extraction_time:.2f}s")
    print(f"Feature shape: {features.shape}")

    # Verify the patch grid. Explicit raises are used instead of ``assert``
    # because assert statements are removed under ``python -O``, which would
    # let the "verified" message and metadata below be emitted unchecked.
    expected_patches_h = H // patch_size
    expected_patches_w = W // patch_size
    if num_patches_h != expected_patches_h:
        raise AssertionError(
            f"Patch height mismatch: {num_patches_h} != {expected_patches_h}"
        )
    if num_patches_w != expected_patches_w:
        raise AssertionError(
            f"Patch width mismatch: {num_patches_w} != {expected_patches_w}"
        )
    print(f"✓ Patch alignment verified: 1 patch = {patch_size}×{patch_size} pixels")

    # 4. Generate the PCA visualisation
    print("\n[4/5] Generating PCA visualization...")
    pca_image = self.generate_pca_visualization(
        features, num_patches_h, num_patches_w, (H, W)
    )
    pca_path = output_path / "pca_rainbow.png"
    Image.fromarray(pca_image).save(pca_path)
    print(f"Saved PCA visualization to: {pca_path}")
    print(f"PCA image size: {pca_image.shape[1]}×{pca_image.shape[0]}")

    # 5. Generate the heatmap
    print("\n[5/5] Generating heatmap...")
    heatmap_image = self.generate_heatmap(
        features, num_patches_h, num_patches_w, (H, W)
    )
    heatmap_path = output_path / "pca_heatmap.png"
    Image.fromarray(heatmap_image).save(heatmap_path)
    print(f"Saved heatmap to: {heatmap_path}")
    print(f"Heatmap size: {heatmap_image.shape[1]}×{heatmap_image.shape[0]}")

    # 6. Save run metadata
    print("\nSaving metadata...")
    metadata = {
        "input_image": str(input_path),
        "original_size": {"width": original_w, "height": original_h},
        "upscale_factor": upscale,
        "processing_size": {"width": W, "height": H},
        "model": self.model_path,
        "device": self.device,
        "gpu_id": self.gpu_id if self.device == "cuda" else None,
        "patch_grid": {"height": num_patches_h, "width": num_patches_w},
        "num_patches": num_patches_h * num_patches_w,
        "feature_dim": features.shape[-1],
        "tile_size": self.tile_size,
        "overlap": self.overlap,
        "extraction_time_seconds": extraction_time,
        "patch_size_pixels": patch_size,
        "verification": f"1 patch = {patch_size}×{patch_size} pixels ✓",
    }

    metadata_path = output_path / "summary.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    print(f"Saved metadata to: {metadata_path}")

    print("\n" + banner)
    print("Processing completed successfully!")
    print(banner)
    print("\nOutput files:")
    print(f"  - Upsampled image: {upsampled_path}")
    print(f"  - PCA visualization: {pca_path}")
    print(f"  - Heatmap: {heatmap_path}")
    print(f"  - Metadata: {metadata_path}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    Parses the CLI options, downgrades the device to CPU when CUDA is not
    available, resets an out-of-range GPU id to 0, then builds a
    ``NativeResolutionFeatureExtractor`` and calls its ``process`` method.
    """
    cli = argparse.ArgumentParser(
        description="DINOv3 Native Resolution Feature Extraction"
    )

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated help text lists the options in the same sequence.
    option_table = (
        ("--input", dict(type=str, required=True, help="Input image path")),
        ("--output", dict(type=str, required=True, help="Output directory")),
        ("--model", dict(
            type=str,
            default="/data/yaoju/model_dinoV3/dinov3-vit7b16-pretrain-sat493m",
            help="DINOv3 model path",
        )),
        ("--upscale", dict(type=int, default=4, help="Upscale factor (default: 4)")),
        ("--tile-size", dict(
            type=int,
            default=512,
            help="Tile size for processing (default: 512)",
        )),
        ("--overlap", dict(
            type=int,
            default=64,
            help="Tile overlap in pixels (default: 64)",
        )),
        ("--device", dict(
            type=str,
            default="cuda",
            choices=["cuda", "cpu"],
            help="Device to use (default: cuda)",
        )),
        ("--gpu-id", dict(
            type=int,
            default=0,
            help="GPU device ID to use (default: 0)",
        )),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    # Fall back to CPU when CUDA was requested but is not available
    wants_cuda = opts.device == "cuda"
    if wants_cuda and not torch.cuda.is_available():
        print("WARNING: CUDA not available, falling back to CPU")
        opts.device = "cpu"

    # Validate the requested GPU index against the visible device count
    if opts.device == "cuda":
        gpu_total = torch.cuda.device_count()
        if not (0 <= opts.gpu_id < gpu_total):
            print(f"WARNING: GPU ID {opts.gpu_id} is invalid. Available GPUs: {gpu_total}")
            print("Using default GPU ID: 0")
            opts.gpu_id = 0

    # Build the extractor and run the pipeline
    runner = NativeResolutionFeatureExtractor(
        model_path=opts.model,
        device=opts.device,
        tile_size=opts.tile_size,
        overlap=opts.overlap,
        gpu_id=opts.gpu_id,
    )
    runner.process(
        input_path=opts.input,
        output_dir=opts.output,
        upscale=opts.upscale,
    )
|
||||
|
||||
|
||||
# Script entry guard: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue