Simple CIFAR-10 Examples of Three Kinds of Knowledge Distillation

Peng Xia

Knowledge distillation is a technique for transferring the knowledge of a large, computationally expensive model into a smaller model without losing much of its effectiveness. This makes it possible to deploy models on weaker hardware and to run inference faster and more efficiently.

In short, knowledge distillation trains the student model to reproduce what the teacher model produces at some stage (i.e., at certain layers), combined with the usual supervised learning against the labels, so that the student aligns with the teacher. The table below lists how the loss is computed for each level of alignment covered in this post.

Loss type | Category | Summary
Baseline (CE) | — | The student is trained with ordinary CrossEntropy supervision; no teacher involved.
KL divergence (soft targets) | Logits-based KD | The classic Hinton KD (softmax + KL), matching the teacher's and student's soft logits.
Cosine Embedding Loss | Representation-based KD | Aligns the direction of some layer's output (e.g., an embedding) between teacher and student.
Intermediate regressor loss (MSE) | Feature-based KD / hint-based KD | Aligns a student intermediate layer with a teacher layer (usually MSE/L2), e.g., a convolutional feature map or an MLP layer.
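All three distillation setups below combine their losses in the same way: an ordinary supervised term plus a weighted alignment term, only swapping what is aligned and which loss measures the alignment. Schematically (the weights $w_{\text{CE}}$ and $w_{\text{align}}$ here are generic placeholders; each code block below uses its own names):

$$
\mathcal{L}_{\text{total}} = w_{\text{CE}} \cdot \mathcal{L}_{\text{CE}}(\text{student logits}, \text{labels}) + w_{\text{align}} \cdot \mathcal{L}_{\text{align}}(\text{student output}, \text{teacher output})
$$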
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Use the GPU (CUDA) if it is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
Using cuda device

CIFAR-10 is a widely used image dataset with ten classes. Our goal is to predict one of the ten classes for every input image.

[Figure: CIFAR-10]

# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
Files already downloaded and verified
Files already downloaded and verified

Independent supervised training of the teacher and student models

Both architectures are convolutional neural networks (CNNs) with a different number of convolutional layers serving as feature extractors, followed by a classifier over the 10 classes. The student uses fewer convolutional kernels and has far fewer parameters.

# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
def __init__(self, num_classes=10):
super(DeepNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)

def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
def __init__(self, num_classes=10):
super(LightNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)

def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

Below are the training and test functions for the baseline models: the loss is the cross-entropy between the logits and the labels, and the metric is accuracy.

def train(model, train_loader, epochs, learning_rate, device):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()

for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
# inputs: A collection of batch_size images
# labels: A vector of dimensionality batch_size with integers denoting class of each image
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()
outputs = model(inputs)

# outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
# labels: The actual labels of the images. Vector of dimensionality batch_size
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

running_loss += loss.item()

print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
model.to(device)
model.eval()

correct = 0
total = 0

with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)

outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)
correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy

Baseline (CE)

Set the torch manual seed so the results are reproducible.

torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)
Epoch 1/10, Loss: 1.4504150144584345
Epoch 2/10, Loss: 0.9319331475231044
Epoch 3/10, Loss: 0.7057780996917764
Epoch 4/10, Loss: 0.5456663286289596
Epoch 5/10, Loss: 0.40601869922159883
Epoch 6/10, Loss: 0.28496988056718237
Epoch 7/10, Loss: 0.20859538298814803
Epoch 8/10, Loss: 0.15263370982826213
Epoch 9/10, Loss: 0.1255439391069095
Epoch 10/10, Loss: 0.10921698193901869
Test Accuracy: 75.68%
# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296

We print the norm of the first layer here only to confirm that the two instances of the same lightweight architecture are initialized identically, so that one can be trained by plain supervision and the other by knowledge distillation.

total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
DeepNN parameters: 2,495,914
LightNN parameters: 267,738

The student has only about one tenth as many parameters as the teacher.

train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)
Epoch 1/10, Loss: 1.470865885315039
Epoch 2/10, Loss: 1.1615625789098423
Epoch 3/10, Loss: 1.0311961670970673
Epoch 4/10, Loss: 0.9303452948780011
Epoch 5/10, Loss: 0.8560442997671455
Epoch 6/10, Loss: 0.790915090104808
Epoch 7/10, Loss: 0.727434201466153
Epoch 8/10, Loss: 0.672618583721273
Epoch 9/10, Loss: 0.6193401737286307
Epoch 10/10, Loss: 0.5703270701343751
Test Accuracy: 70.50%
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")
Teacher accuracy: 75.68%
Student accuracy: 70.50%

Logits-based KD

The soft targets here are the teacher's logits; compared with the hard targets (the labels) they are less absolute, hence the name. We compute the KL divergence between the teacher's and the student's logits to align the two.

[Figure: Logits-based KD]

Assume:

  • $z_t$ are the teacher's logits
  • $z_s$ are the student's logits
  • $T$ is the temperature
  • $y$ is the ground-truth (hard) label

1. Soft loss (distillation term): KL divergence

The teacher's logits are smoothed with the temperature $T$ before the softmax; the student uses the same temperature:

$$
\mathcal{L}_{\text{KD}} = \mathrm{KL}\left(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\right)
$$

  • Implemented with log_softmax + softmax together with KLDivLoss.
  • The larger $T$ is, the smoother the distributions, which exposes the probability differences among the non-dominant classes.

2. Hard loss (supervision term): CrossEntropy

This is the ordinary classification loss against the ground-truth labels.

3. Total loss

Combining the two terms:

$$
\mathcal{L} = \alpha \cdot \mathcal{L}_{\text{CE}} + (1 - \alpha) \cdot T^2 \cdot \mathcal{L}_{\text{KD}}
$$

  • $\alpha$ is the weighting coefficient (typically 0.5–0.9)
  • $T^2$ is a compensation factor: dividing the logits by $T$ scales the gradients of the KL term down by $1/T^2$, so the term is multiplied by $T^2$ in the forward pass to restore its magnitude.
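As a quick standalone illustration (on random logits, independent of the training function below), the distillation term can be computed with torch.nn.functional.kl_div, which is equivalent to the manual sum used in the training function that follows:

import torch
import torch.nn.functional as F

T = 2.0
teacher_logits = torch.randn(128, 10)   # dummy teacher logits
student_logits = torch.randn(128, 10)   # dummy student logits

# Temperature-softened distributions
soft_targets = F.softmax(teacher_logits / T, dim=-1)
soft_prob = F.log_softmax(student_logits / T, dim=-1)

# KL(teacher || student), averaged over the batch and scaled by T**2
kd_loss = F.kl_div(soft_prob, soft_targets, reduction="batchmean") * (T ** 2)
print(kd_loss.item())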
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)

teacher.eval() # Teacher set to evaluation mode
student.train() # Student to train mode

for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()

# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
with torch.no_grad():
teacher_logits = teacher(inputs)

# Forward pass with the student model
student_logits = student(inputs)

# Soften the teacher's logits with softmax and the student's with log_softmax, both at temperature T
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
# KL(P || Q) = sum P * (log P - log Q)
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)

# Weighted sum of the two losses
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

loss.backward()
optimizer.step()

running_loss += loss.item()

print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
Epoch 1/10, Loss: 2.4518108270357333
Epoch 2/10, Loss: 1.9292995130924313
Epoch 3/10, Loss: 1.7038319656611098
Epoch 4/10, Loss: 1.5407711703454137
Epoch 5/10, Loss: 1.410411078606725
Epoch 6/10, Loss: 1.2932921629732528
Epoch 7/10, Loss: 1.1938754214959986
Epoch 8/10, Loss: 1.1060839129225981
Epoch 9/10, Loss: 1.0279690643100787
Epoch 10/10, Loss: 0.9521754057815922
Test Accuracy: 71.17%
Teacher accuracy: 75.68%
Student accuracy without teacher: 70.50%
Student accuracy with CE + KD: 71.17%

Representation-based KD

[Figure: Representation-based KD]

Adding extra loss functions on top of the main objective is a common and simple way to obtain better generalization in neural networks. This time, instead of the output layer, we add an objective for the student that targets its hidden state. The goal is to pass the teacher's representation to the student through a simple loss: minimizing it means that the flattened embedding vectors fed into the classifiers become increasingly similar as the loss decreases.

The point is for the student to learn an embedding similar to the teacher's; since the embedding largely determines what the classifier head sees, a similar embedding should also lead to better classification results.

In the model definitions below, we modify the networks so that the forward pass returns the embedding along with the logits. This is a common pattern; Transformers, for example, often expose hidden states as well, in their case to reuse already-computed key/value caches.

class ModifiedDeepNNCosine(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedDeepNNCosine, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)

def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedLightNNCosine, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)

def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
return x, flattened_conv_output

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())

# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())
Norm of 1st layer for deep_nn: 7.297417163848877
Norm of 1st layer for modified_deep_nn: 7.297417163848877
Norm of 1st layer: 2.327361822128296

Here the two embeddings end up exactly the same size (the teacher's 2048-dimensional flattened output is average-pooled down to 1024), so the loss can be computed between them directly, with no extra projection needed to align their dimensions.
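As a quick sanity check (using a random tensor in place of the real conv output), the halving performed inside the teacher's forward pass looks like this:

import torch
import torch.nn.functional as F

# The teacher's flattened conv output has 2048 features per image, the student's 1024.
# Average pooling with kernel size 2 halves the teacher's vector so the two match.
teacher_flat = torch.randn(128, 2048)               # batch_size x 2048
pooled = F.avg_pool1d(teacher_flat, kernel_size=2)  # same op as in ModifiedDeepNNCosine.forward
print(pooled.shape)                                 # torch.Size([128, 1024])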

# Create a sample input tensor
sample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Filters: 3, Image size: 32x32

# Pass the input through the student
logits, hidden_representation = modified_nn_light(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape) # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

# Pass the input through the teacher
logits, hidden_representation = modified_nn_deep(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", logits.shape) # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size
Student logits shape: torch.Size([128, 10])
Student hidden representation shape: torch.Size([128, 1024])
Teacher logits shape: torch.Size([128, 10])
Teacher hidden representation shape: torch.Size([128, 1024])

Next we replace the KL divergence with a cosine embedding loss.

[Figure: CosineEmbeddingLoss]

Given two embeddings $x_1$ and $x_2$ and a target $y \in \{1, -1\}$:

$$
\text{loss}(x_1, x_2, y) =
\begin{cases}
1 - \cos(x_1, x_2), & \text{if } y = 1 \\
\max\bigl(0, \cos(x_1, x_2) - \text{margin}\bigr), & \text{if } y = -1
\end{cases}
$$

With $y = 1$, the two embeddings should be as similar as possible: their cosine similarity should be as high as possible, with the angle between them approaching 0. With $y = -1$, the cosine similarity should instead be small, though not strictly; it only needs to fall below the margin.
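A minimal sketch of how this loss is invoked on dummy embeddings (shapes chosen to match the models above; the training function below does the same with real hidden representations):

import torch
import torch.nn as nn

cosine_loss = nn.CosineEmbeddingLoss()

# Dummy hidden representations standing in for the student's and teacher's embeddings
student_emb = torch.randn(128, 1024)
teacher_emb = torch.randn(128, 1024)

# A target of +1 for every pair asks the loss to pull the embeddings toward
# the same direction (1 - cosine similarity); -1 would push them apart.
target = torch.ones(128)
loss = cosine_loss(student_emb, teacher_emb, target)
print(loss.item())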

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
ce_loss = nn.CrossEntropyLoss()
cosine_loss = nn.CosineEmbeddingLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)

teacher.to(device)
student.to(device)
teacher.eval() # Teacher set to evaluation mode
student.train() # Student to train mode

for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()

# Forward pass with the teacher model and keep only the hidden representation
with torch.no_grad():
_, teacher_hidden_representation = teacher(inputs)

# Forward pass with the student model
student_logits, student_hidden_representation = student(inputs)

# Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)

# Weighted sum of the two losses
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

loss.backward()
optimizer.step()

running_loss += loss.item()

print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
def test_multiple_outputs(model, test_loader, device):
model.to(device)
model.eval()

correct = 0
total = 0

with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)

outputs, _ = model(inputs) # Disregard the second tensor of the tuple
_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)
correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
# Train and test the lightweight network with CE + cosine embedding loss
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
Epoch 1/10, Loss: 1.2930794695149297
Epoch 2/10, Loss: 1.0589119400209783
Epoch 3/10, Loss: 0.9568290230258346
Epoch 4/10, Loss: 0.8834206243915022
Epoch 5/10, Loss: 0.827963415466611
Epoch 6/10, Loss: 0.7844577648145769
Epoch 7/10, Loss: 0.7424624330552337
Epoch 8/10, Loss: 0.7057774653825004
Epoch 9/10, Loss: 0.6661281107026903
Epoch 10/10, Loss: 0.6384436357814027
Test Accuracy: 70.53%

Feature-based KD

[Figure: Feature-based KD]

The teacher's last conv layer has 32 kernels while the student's has 16, so their feature maps have different shapes and cannot be matched directly. We therefore add a trainable layer that converts the student's feature map to the same shape as the teacher's. In practice, we modify the lightweight model class to return the hidden state after an intermediate regressor that matches the sizes of the convolutional feature maps, and modify the teacher class to return the output of its last convolutional layer.

# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)
Student's feature extractor output shape:  torch.Size([128, 16, 8, 8])
Teacher's feature extractor output shape: torch.Size([128, 32, 8, 8])
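Before defining the full classes, here is a minimal sketch (with a randomly initialized layer, for illustration only) of how a 3×3 convolution can serve as the regressor that lifts the student's 16-channel map to the teacher's 32 channels; the same layer appears inside ModifiedLightNNRegressor below:

import torch
import torch.nn as nn

# Regressor: maps 16-channel student feature maps to 32 channels
# so they can be compared with the teacher's feature maps via MSE.
regressor = nn.Conv2d(16, 32, kernel_size=3, padding=1)

student_map = torch.randn(128, 16, 8, 8)   # student feature extractor output
aligned_map = regressor(student_map)
print(aligned_map.shape)                   # torch.Size([128, 32, 8, 8])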
class ModifiedDeepNNRegressor(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedDeepNNRegressor, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)

def forward(self, x):
x = self.features(x)
conv_feature_map = x
x = torch.flatten(x, 1)
x = self.classifier(x)
return x, conv_feature_map

class ModifiedLightNNRegressor(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedLightNNRegressor, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# Include an extra regressor (in our case linear)
self.regressor = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=3, padding=1)
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)

def forward(self, x):
x = self.features(x)
regressor_output = self.regressor(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x, regressor_output

This time we also drop the cosine embedding loss and instead apply an MSE loss directly between the two feature maps, pushing them to become identical, whereas the earlier cosine embedding loss only pushed the embeddings toward the same direction.
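A small illustration of that difference, on random tensors: scaling one input changes the MSE but not the cosine similarity, since only the MSE cares about magnitude.

import torch
import torch.nn as nn
import torch.nn.functional as F

mse_loss = nn.MSELoss()

a = torch.randn(128, 32, 8, 8)
b = 2.0 * a  # same direction, different magnitude

# MSE penalizes the difference in magnitude...
print(mse_loss(a, b).item())                                           # > 0

# ...while the cosine similarity of the flattened maps is unaffected by scale.
print(F.cosine_similarity(a.flatten(1), b.flatten(1)).mean().item())   # ~1.0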

def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)

teacher.to(device)
student.to(device)
teacher.eval() # Teacher set to evaluation mode
student.train() # Student to train mode

for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()

# Again ignore teacher logits
with torch.no_grad():
_, teacher_feature_map = teacher(inputs)

# Forward pass with the student model
student_logits, regressor_feature_map = student(inputs)

# Calculate the loss
hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)

# Weighted sum of the two losses
loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

loss.backward()
optimizer.step()

running_loss += loss.item()

print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
Epoch 1/10, Loss: 1.2503477350220351
Epoch 2/10, Loss: 0.9886605646604162
Epoch 3/10, Loss: 0.8816004340605967
Epoch 4/10, Loss: 0.7988157702224029
Epoch 5/10, Loss: 0.7375963549784688
Epoch 6/10, Loss: 0.6859350952955768
Epoch 7/10, Loss: 0.6347175088837324
Epoch 8/10, Loss: 0.5938715356237748
Epoch 9/10, Loss: 0.556459863015148
Epoch 10/10, Loss: 0.519292121576836
Test Accuracy: 70.37%
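To wrap up, the accuracies gathered above can be printed side by side; this block only reuses variables already defined in this post:

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")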