diff --git a/docs/pytorch.md b/docs/pytorch.md index 39939d64397..8e188bc4038 100644 --- a/docs/pytorch.md +++ b/docs/pytorch.md @@ -198,6 +198,41 @@ torch.Size([1, 2, 28]) # 参数代表在哪里添加维度 0 torch.Size([2, 28, 1]) ``` +Cuda 相关 +--- +### 检查 Cuda 是否可用 +```python +>>> import torch.cuda +>>> torch.cuda.is_available() +>>> True +``` +### 列出 GPU 设备 +```python +import torch +device_count = torch.cuda.device_count() +print("CUDA 设备") +for i in range(device_count): + device_name = torch.cuda.get_device_name(i) + total_memory = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) + print(f"├── 设备 {i}: {device_name}, 容量: {total_memory:.2f} GiB") +print("└── (结束)") +``` +### 将模型、张量等数据在 GPU 和内存之间进行搬运 +```python +import torch +# Replace 0 to your GPU device index. or use "cuda" directly. +device = f"cuda:0" +# Move to GPU +tensor_m = torch.tensor([1, 2, 3]) +tensor_g = tensor_m.to(device) +model_m = torch.nn.Linear(1, 1) +model_g = model_m.to(device) +# Move back. +tensor_m = tensor_g.cpu() +model_m = model_g.cpu() +``` + + 导入 Imports ---