wandb Usage and Examples

Peng Xia

Weights & Biases (wandb)

Weights & Biases (wandb) makes it very convenient to record loss, accuracy, model weights, learning-rate curves, hyperparameter configurations, version control, model visualizations, and more over the course of training.

The complete getting-started steps for wandb are below:


1. Install wandb

pip install wandb

2. Register an account and log in

You need to register an account at https://wandb.ai. Note that at the moment only accounts registered for personal use are free.

Then run the login command (you only need to do this once):

wandb login

It will ask you to paste a token (provided on the website after you register an account); once you enter it, you are logged in.

Example output:

$ wandb login
wandb: WARNING Using legacy-service, which is deprecated. If this is unintentional, you can fix it by ensuring you do not call `wandb.require('legacy-service')` and do not set the WANDB_X_REQUIRE_LEGACY_SERVICE environment variable.
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /home/guest/.netrc
wandb: Currently logged in as: **YOUR_ACCOUNT** to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
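
If you are on a remote server or in CI where the interactive prompt is inconvenient, you can also log in non-interactively. A minimal sketch, assuming the API key has been exported beforehand as the WANDB_API_KEY environment variable (the key itself comes from https://wandb.ai/authorize):

import os
import wandb

# Assumes `export WANDB_API_KEY=...` was run beforehand
api_key = os.environ.get("WANDB_API_KEY")
if api_key:
    wandb.login(key=api_key)  # non-interactive login
else:
    wandb.login()             # falls back to the interactive prompt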

3. Basic usage (added inside your training script)

The most basic way to use it in a PyTorch training script is as follows:

import torch
import wandb

# ① Initialize the project
wandb.init(project="my_mnist_project", config={
    "epochs": 5,
    "batch_size": 64,
    "lr": 1e-3,
})

# ② Read hyperparameters from wandb.config (optional)
config = wandb.config

# ③ Log metrics every step / every epoch
for epoch in range(config.epochs):
    train_loss = ...  # your training loop
    acc = ...
    # These metrics can be plotted as curves on the web UI;
    # the x-axis is whatever period you log, e.g. step or epoch
    wandb.log({"loss": train_loss, "accuracy": acc, "epoch": epoch})

# ④ Save the model (optional)
torch.save(model.state_dict(), "model.pt")
wandb.save("model.pt")
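
Beyond scalar metrics, wandb can also hook the model itself to record gradient and parameter histograms. A minimal sketch using wandb.watch(), assuming `model` is the nn.Module being trained and the run has already been created with wandb.init(); the log_freq value here is only illustrative:

# Log gradient histograms every 100 batches for the model being trained.
# Call this once after wandb.init() and before the training loop.
wandb.watch(model, log="gradients", log_freq=100)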

4. Output when running code that uses wandb

wandb: Currently logged in **YOUR_ACCOUNT** to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.19.10
wandb: Run data is saved locally in **YOUR_LOCAL_SAVE_PATH**
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run royal-brook-1
wandb: ⭐️ View project at **YOUR_WANDB_URL**
wandb: 🚀 View run at **YOUR_WANDB_URL**
---
This is where your own script's output appears.
---
wandb:
wandb:
wandb: Run history:
wandb: epoch ▁▃▅▆█
wandb: epoch_time █▂▂▁▁
wandb: lr █▄▂▁▁
wandb: test_acc ▁▅▇██
wandb: test_loss █▄▂▁▁
wandb: train_loss █▂▁▁▁
wandb:
wandb: Run summary:
wandb: epoch 5
wandb: epoch_time 9.86671
wandb: lr 6e-05
wandb: test_acc 0.9867
wandb: test_loss 0.03622
wandb: train_loss 0.05843
wandb:
wandb: 🚀 View run comfy-wildflower-2 at: **YOUR_WANDB_URL**
wandb: ⭐️ View project at: **YOUR_WANDB_URL**
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20250503_060654-lnl2z1vr/logs
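
The log above points out that `wandb offline` turns off syncing from the command line. The same thing can be controlled from inside the script via the `mode` argument of wandb.init(); a small sketch of the two common settings:

import wandb

# Log locally only; upload later with `wandb sync` if you want the run online.
wandb.init(project="my_mnist_project", mode="offline")

# Or turn every wandb call into a no-op, handy for quick debugging runs.
# wandb.init(project="my_mnist_project", mode="disabled")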

MNIST Example

Below is the code for training an MNIST classifier with a custom Trainer. All of the wandb calls live in the main function; avoid putting them inside the Trainer where possible.

import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from time import time
import argparse
from tqdm import tqdm
import wandb


def compute_accuracy(preds: torch.Tensor, target: torch.Tensor, from_logits: bool = True) -> float:
    """
    Compute accuracy.

    Args:
    - preds: model output of shape [batch_size, num_classes], or [batch_size] if already class labels
    - target: ground-truth labels of shape [batch_size]
    - from_logits: if True, preds are logits and argmax is taken; if False, preds are already predicted labels

    Returns:
    - accuracy (float)
    """
    if from_logits:
        pred_labels = torch.argmax(preds, dim=1)
    else:
        pred_labels = preds

    correct = (pred_labels == target).sum().item()
    total = target.numel()

    return correct / total


class ConvNet(nn.Module):

    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        return output


# Custom Dataset
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx]
        return image, label


class Trainer:

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        gpu_id: int = 0,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.optimizer = optimizer

    def save_checkpoint(self):
        ckp = self.model.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Training checkpoint saved at {PATH}")

    def _train_batch(self, source, targets):
        self.model.train()
        self.optimizer.zero_grad()

        source = source.to(self.gpu_id)
        targets = targets.to(self.gpu_id)

        output = self.model(source)
        loss = F.cross_entropy(output, targets)

        loss.backward()
        self.optimizer.step()

        batch_loss = loss.detach().item()
        return batch_loss

    def _test_batch(self, source, targets):
        self.model.eval()

        with torch.no_grad():
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)

            output = self.model(source)
            loss = F.cross_entropy(output, targets)

            batch_loss = loss.detach().item()
            accuracy = compute_accuracy(output, targets)

        return batch_loss, accuracy

    def train_epoch(self, train_dataloader):
        total_loss = 0.0
        num_batches = 0

        for source, targets in tqdm(train_dataloader, desc="Training", mininterval=1):
            loss = self._train_batch(source, targets)
            total_loss += loss
            num_batches += 1

        avg_loss = total_loss / num_batches
        return avg_loss

    def test_epoch(self, test_dataloader):
        total_loss = 0.0
        num_batches = 0
        total_accuracy = 0.0

        for source, targets in tqdm(test_dataloader, desc="Testing", mininterval=1):
            loss, accuracy = self._test_batch(source, targets)
            total_loss += loss
            total_accuracy += accuracy
            num_batches += 1

        avg_loss = total_loss / num_batches
        avg_accuracy = total_accuracy / num_batches
        return avg_loss, avg_accuracy



def prepare_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_data = datasets.MNIST(
        root='./mnist',
        train=True,           # True for the training split, False for the test split
        transform=transform,
        # download=True       # set to True to download automatically; switch back to False once downloaded
    )

    train_set = MyDataset(train_data)

    test_data = datasets.MNIST(
        root='./mnist',
        train=False,          # True for the training split, False for the test split
        transform=transform,
    )

    test_set = MyDataset(test_data)

    return train_set, test_set



def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )


def arg_parser():
    parser = argparse.ArgumentParser(description="MNIST Training Script")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")
    parser.add_argument("--lr_decay_step_num", type=int, default=1, help="Step size for learning rate decay")
    parser.add_argument("--lr_decay_factor", type=float, default=0.5, help="Factor by which learning rate is decayed")
    parser.add_argument("--cuda_id", type=int, default=0, help="CUDA device ID to use")
    parser.add_argument('--save_every', type=int, default=1, help='How often to save a snapshot')
    return parser.parse_args()




if __name__ == "__main__":

    args = arg_parser()
    print(f"Training arguments: {args}")

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LR = args.lr
    LR_DECAY_STEP_NUM = args.lr_decay_step_num
    LR_DECAY_FACTOR = args.lr_decay_factor
    CUDA_ID = args.cuda_id
    DEVICE = torch.device(f"cuda:{CUDA_ID}")
    SAVE_EVERY = args.save_every

    # Initialize wandb
    wandb_run = wandb.init(
        project="mnist",  # project name; customize as you like
        config={
            "epochs": EPOCHS,
            "batch_size": BATCH_SIZE,
            "lr": LR,
            "lr_decay_step_num": LR_DECAY_STEP_NUM,
            "lr_decay_factor": LR_DECAY_FACTOR,
        }
    )

    # prepare dataloader
    train_set, test_set = prepare_dataset()
    print(f"Train dataset size: {len(train_set)}")
    print(f"Test dataset size: {len(test_set)}")
    train_dataloader = prepare_dataloader(dataset=train_set, batch_size=BATCH_SIZE)
    test_dataloader = prepare_dataloader(dataset=test_set, batch_size=BATCH_SIZE)

    # prepare model
    model = ConvNet()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

    # init trainer
    trainer = Trainer(model, optimizer, CUDA_ID)

    for epoch in range(EPOCHS):

        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f'Learning Rate: {scheduler.get_last_lr()[0]}')

        start_time = time()

        train_loss = trainer.train_epoch(train_dataloader)
        test_loss, test_accuracy = trainer.test_epoch(test_dataloader)
        print(f"Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        epoch_time = time() - start_time
        current_lr = scheduler.get_last_lr()[0]

        wandb_run.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "test_loss": test_loss,
            "test_acc": test_accuracy,
            "lr": current_lr,
            "epoch_time": epoch_time
        })

        scheduler.step()

        if (epoch + 1) % SAVE_EVERY == 0:
            trainer.save_checkpoint()

    wandb_run.finish()
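
To launch a run like the one logged below, save the script (here assumed to be named train_mnist.py; the filename is arbitrary) and start it with the command-line arguments defined in arg_parser, for example:

python train_mnist.py --epochs 5 --batch_size 512 --lr 0.0005 --cuda_id 0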

The full wandb output is as follows:

Training arguments: Namespace(epochs=5, batch_size=512, lr=0.0005, lr_decay_step_num=1, lr_decay_factor=0.5, cuda_id=0, save_every=1)
wandb: Currently logged in as: **PERSONAL_INFO** to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.19.10
wandb: Run data is saved locally in **PERSONAL_INFO**
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run proud-oath-4
wandb: ⭐️ View project at **PERSONAL_INFO**
wandb: 🚀 View run at **PERSONAL_INFO**
Train dataset size: 60000
Test dataset size: 10000
Epoch 1/5
Learning Rate: 0.001
Training: 100%|████████████████████████████████████████████| 118/118 [00:08<00:00, 13.34it/s]
Testing: 100%|███████████████████████████████████████████████| 20/20 [00:01<00:00, 14.14it/s]
Train Loss: 0.3594, Test Loss: 0.0741, Test Accuracy: 0.9763
Training checkpoint saved at checkpoint.pt
Epoch 2/5
Learning Rate: 0.0005
Training: 100%|████████████████████████████████████████████| 118/118 [00:08<00:00, 13.74it/s]
Testing: 100%|███████████████████████████████████████████████| 20/20 [00:01<00:00, 14.53it/s]
Train Loss: 0.1085, Test Loss: 0.0536, Test Accuracy: 0.9831
Training checkpoint saved at checkpoint.pt
Epoch 3/5
Learning Rate: 0.00025
Training: 100%|████████████████████████████████████████████| 118/118 [00:08<00:00, 13.90it/s]
Testing: 100%|███████████████████████████████████████████████| 20/20 [00:01<00:00, 14.62it/s]
Train Loss: 0.0849, Test Loss: 0.0450, Test Accuracy: 0.9848
Training checkpoint saved at checkpoint.pt
Epoch 4/5
Learning Rate: 0.000125
Training: 100%|████████████████████████████████████████████| 118/118 [00:08<00:00, 13.85it/s]
Testing: 100%|███████████████████████████████████████████████| 20/20 [00:01<00:00, 14.46it/s]
Train Loss: 0.0737, Test Loss: 0.0398, Test Accuracy: 0.9870
Training checkpoint saved at checkpoint.pt
Epoch 5/5
Learning Rate: 6.25e-05
Training: 100%|████████████████████████████████████████████| 118/118 [00:08<00:00, 13.54it/s]
Testing: 100%|███████████████████████████████████████████████| 20/20 [00:01<00:00, 14.60it/s]
Train Loss: 0.0680, Test Loss: 0.0391, Test Accuracy: 0.9870
Training checkpoint saved at checkpoint.pt
wandb:
wandb:
wandb: Run history:
wandb: epoch ▁▃▅▆█
wandb: epoch_time █▃▁▂▅
wandb: lr █▄▂▁▁
wandb: test_acc ▁▅▇██
wandb: test_loss █▄▂▁▁
wandb: train_loss █▂▁▁▁
wandb:
wandb: Run summary:
wandb: epoch 5
wandb: epoch_time 10.08319
wandb: lr 6e-05
wandb: test_acc 0.98698
wandb: test_loss 0.03908
wandb: train_loss 0.068
wandb:
wandb: 🚀 View run proud-oath-4 at: **PERSONAL_INFO**
wandb: ⭐️ View project at: **PERSONAL_INFO**
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20250503_063703-b6zlso17/logs