tongbuxiugai

pull/13/head
chenzenghui 2025-03-21 16:04:39 +08:00
parent 2b586efe03
commit d1806f7427
1 changed files with 33 additions and 11 deletions

View File

@ -353,18 +353,40 @@ extern "C" void FreeCUDAHost(void* ptr) {
ptr = nullptr; ptr = nullptr;
} }
// GPU参数内存声明 // 多GPU内存分配函数
extern "C" void* mallocCUDADevice(size_t memsize) { void* mallocCUDADevice(size_t memsize, int device_id = 0)
void* ptr; {
cudaMalloc(&ptr, memsize); void* ptr = nullptr;
#ifdef __CUDADEBUG__ cudaError_t err;
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) { // 1. 检查设备ID有效性
printf("mallocCUDADevice CUDA Error: %s, malloc memory : %d byte\n", cudaGetErrorString(err), memsize); int num_devices;
exit(2); cudaGetDeviceCount(&num_devices);
if (device_id < 0 || device_id >= num_devices)
{
printfinfo("Invalid device ID: %d\n", device_id);
return nullptr;
} }
#endif // __CUDADEBUG__
cudaDeviceSynchronize(); // 2. 切换目标GPU设备
err = cudaSetDevice(device_id);
if (err != cudaSuccess)
{
PrintLasterError("cudaSetDevice");
return nullptr;
}
// 3. 分配显存
err = cudaMalloc(&ptr, memsize);
if (err != cudaSuccess)
{
PrintLasterError("cudaMalloc");
return nullptr;
}
// 4. 可选:同步设备(视需求决定是否保留)
// cudaDeviceSynchronize();
return ptr; return ptr;
} }