ViT - vision transformer

Peng Xia

ViT (vision transformer) 这个模型的重点在于对特征的提炼能力,预训练只用简单的 softmax 作为 分类头,使得特征提取尽量优化,而不是更复杂的分类头

BG

将自注意力机制直接应用于图像时,需要每个像素与所有其他像素进行交互。这种全连接的方式导致计算复杂度随像素数量呈二次增长,因此难以扩展到实际图像尺寸。为了解决这一问题,以下多种策略在图像处理中应用 Transformer :

  • 局部像素注意力(Local Attention):只在邻近像素之间施加注意力,从而构建局部的多头自注意力模块,可在一定程度上替代卷积操作。
  • 全局注意力近似(Global Attention Approximation):如 Sparse Transformer 提出了可扩展的稀疏注意力机制,使得 Transformer 能够处理大规模图像输入。
  • 基于图像块的注意力(Patch-wise Attention):将输入图像划分为固定大小的图像块(如 2×2 patch),在每个块之间施加完整的自注意力机制。Cordonnier 等人的工作即采用了这种策略,并在较小图像上进行了实验。

ViT(Vision Transformer)本质上延续了最后一种思路,但对其进行了尺度扩展。相比于早期方法仅处理 2×2 的小块,ViT采用更大尺寸的 patch,从而能够适应中等分辨率的图像任务,实现了更强的表达能力和更广的应用范围。

BERT

因为 ViT 底层设计和 BERT 很类似,因此先讲下 BERT 的基本原理。

BERT 的 input是一条文本。文本中的每个词(token)我们都通过 embedding 矩阵 把它表示成了向量的形式。

embedding = token_embedding(将单个词转变为词向量) + position_embedding(位置编码,用于表示 token 在输入序列中的位置) + **segment_emebdding(**非必须,在 bert 中用于表示每个词属于哪个句子)。

在 VIT 中,同样存在 token_embedding 和 postion_emebedding。

在 Bert 中,我们同时做 2 个训练任务:

  • Next Sentence Prediction Model(下一句预测):input 中会包含两个句子,这两个句子有 50% 的概率是真实相连的句子,50% 的概率是随机组装在一起的句子。我们在每个 input 前面增加特殊符<cls>,这个位置所在的 token 将会在训练里不断学习整条文本蕴含的信息。最后它将作为 “下一句预测” 任务的输入向量,该任务是一个二分类模型,输出结果表示两个句子是否真实相连。
  • Masked Language Model(遮蔽词猜测):在 input 中,我们会以一定概率随机遮盖掉一些 token(<mask>),以此来强迫模型通过 Bert 中的 attention 结构更好抽取上下文信息,然后在 “遮蔽词猜测” 任务重,准确地将被覆盖的词猜测出来。

就和这次讲的 ViT 一样, BERT也只是重点放在特征的提取,同样也最多适用于分类任务。

ViT model design

image-20250413023931052

image-20250413025715087

标准transformer接受 1D序列的 embedding (实际整体是2D)。为了针对 2D 序列的图片,要把图片扁平为 2D 的patch。将 HxWxC 的图片转换为 Nx(P^2xC) 的patch,注意是把 图片尺寸 HxW 切分成 P^2 的小块。

  • patch embeddings: 之后为了转换为 embedding,将每个patch扁平化,并用linear层映射到embedding,类似于bert的 [cls] token,我们也需要用一个token,来作为最终的整体图片表示,这个 token 最终的标表征用于分类任务,实现起来就是定义一组可学习参数。transformers 里源码是定义 (1,1,hs) 的矩阵,FWD时会在batch维度扩展,再在第二维度(序列维度)上拼接。
1
2
3
4
5
# init
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
# foward
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  • position embedding: 另外将 position embedding添加到 patch embedding中以保留位置信息。我们使用标准可学习的1D位置嵌入,因为我们没有观察到使用更高级的2D感知能position embedding的性能提高,源码里除了正常图片的处理外,还有非正常大小图片的处理,大概就是将原本的 PE 插值化到目标图片的大小。
1
2
3
4
5
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
  • transformer blocks: 之后整体就和正常 transformer 一样了,都是 attention 和MLP的叠加,最后有一个分类head,是由MLP在预训练时间和微调时间的单个线性层的MLP实现的。

通常,我们在大型数据集上预先培训VIT,并对(较小)下游任务进行微调。为此,我们删除了预训练的预测头,并连接一个初始化为0的d×k的FFN,其中k是下游类的数量。比预训练更高的分辨率进行微分解通常是有益的。在喂食较高分辨率的图像时,我们保持patch大小相同,从而导致较大的有效序列长度。ViT可以处理任意序列长度(直至记忆约束),但是,预训练的position embedding可能不再有意义。因此,我们根据原始图像中的位置,对预训练的position embedding进行2D插值。

ViT CNN 混合模型

原文在最后还提出了混合结构,作为原始图像patch的替代方法,可以先从CNN的特征图中形成输入序列。在此混合模型中,patch embeddings 应用于从CNN特征图中提取的patch,相当于替代 linear 层提取特征。

As an alternative to raw image patches, the input sequence can be formed from feature maps of a CNN (LeCun et al., 1989). In this hybrid model, the patch embedding projection E (Eq. 1) is applied to patches extracted from a CNN feature map. As a special case, the patches can have spatial size 1x1, which means that the input sequence is obtained by simply flattening the spatial dimensions of the feature map and projecting to the Transformer dimension. The classification input embedding and position embeddings are added as described above.

虽说 ViT 是作为替代 CNN 的方法,但用 CNN 和这个不矛盾,由于这一步只是输入预处理阶段,和主体模型没有关系。

Transformers ViT 微调

预下载数据集和模型

1
2
3
export HF_ENDPOINT="https://hf-mirror.com"
huggingface-cli download --repo-type dataset --resume-download beans --local-dir ./beans
huggingface-cli download google/vit-base-patch16-224-in21k --local-dir vit-base-patch16-224-in21k

loading dataset - beans

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from datasets import load_dataset
import os

# load cifar10 (only small portion for demonstration purposes)
parquet_path = 'beans/data'

ds = load_dataset('parquet',
data_files={
"train": os.path.join(parquet_path, 'train-*.parquet'),
"validation": os.path.join(parquet_path, 'validation-*.parquet'),
"test": os.path.join(parquet_path, 'test-*.parquet')
}
)
ds
1
2
3
4
5
6
7
8
9
10
11
12
13
14
DatasetDict({
train: Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 1034
})
validation: Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 133
})
test: Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 128
})
})

每个样本包含三个特征:

  • image:一个 PIL 图像对象
  • image_file_path:图像文件的路径,类型为字符串,该路径对应的图像已被加载为 image
  • labels:一个 datasets.ClassLabel 特征,这里会以整数形式表示每个样本的标签
1
2
ex = ds['train'][400]
ex
1
2
3
{'image_file_path': '/home/albert/.cache/huggingface/datasets/downloads/extracted/967f0d9f61a7a8de58892c6fab6f02317c06faf3e19fba6a07b0885a9a7142c7/train/bean_rust/bean_rust_train.148.jpg',
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500>,
'labels': 1}

打印下图片

1
2
image = ex['image']
image

main_7_0

打印出该样本对应的类别标签。可以使用 ClassLabel 提供的 int2str 函数来实现,该函数可以将类别的整数表示转换为对应的字符串标签

1
2
3
labels = ds['train'].features['labels']
print(labels)
print(labels.int2str(ex['labels']))
1
2
ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'], id=None)
bean_rust

loading ViT Image Processor

在训练 ViT 模型时,输入图像会经过特定的变换处理。如果对图像应用了错误的变换,模型将无法理解其所看到的内容!

为了确保应用正确的图像变换,我们将使用与预训练模型一同保存的配置来初始化一个 ViTImageProcessor。在本例中,使用的是 google/vit-base-patch16-224-in21k 模型

可以直接打印 ImageProcessor 来显示其 config

1
2
3
4
5
6
from transformers import ViTImageProcessor


model_name_or_path = "../DC/vit-base-patch16-224-in21k"
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
processor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
ViTImageProcessor {
"do_convert_rgb": null,
"do_normalize": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.5,
0.5,
0.5
],
"image_processor_type": "ViTImageProcessor",
"image_std": [
0.5,
0.5,
0.5
],
"resample": 2,
"rescale_factor": 0.00392156862745098,
"size": {
"height": 224,
"width": 224
}
}

要处理一张图像,只需将其传递给图像处理器的 __call__ 函数即可。该函数会返回一个字典,其中包含 pixel_values,即图像的数值表示,可直接输入到模型中。

默认情况下返回的是 NumPy 数组,但如果添加参数 return_tensors='pt',则会返回 PyTorch 张量。

1
2
3
ret = processor(image, return_tensors='pt')
print(ret)
print(ret['pixel_values'].shape)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
{'pixel_values': tensor([[[[ 0.7882,  0.6706,  0.7098,  ..., -0.1922, -0.1294, -0.1765],
[ 0.7098, 0.6000, 0.6784, ..., -0.2863, -0.1608, -0.1608],
[ 0.4902, 0.3882, 0.4667, ..., -0.1922, -0.0196, 0.0275],
...,
[ 0.3804, 0.5294, 0.4824, ..., -0.8275, -0.8196, -0.8039],
[ 0.0902, 0.3725, 0.3804, ..., -0.8667, -0.8431, -0.8510],
[-0.0510, 0.2784, 0.3176, ..., -0.8588, -0.8275, -0.8353]],

[[ 0.4902, 0.3490, 0.3804, ..., -0.6078, -0.5373, -0.5843],
[ 0.3569, 0.2000, 0.3176, ..., -0.7255, -0.6000, -0.5922],
[ 0.0431, -0.0902, 0.0588, ..., -0.6392, -0.4745, -0.4275],
...,
[-0.2235, -0.0510, -0.0902, ..., -0.9686, -0.9529, -0.9294],
[-0.5059, -0.2078, -0.1922, ..., -0.9922, -0.9922, -1.0000],
[-0.6471, -0.2941, -0.2471, ..., -0.9843, -0.9765, -0.9843]],

[[ 0.4196, 0.2706, 0.3020, ..., -0.7098, -0.6392, -0.6863],
[ 0.2314, 0.0824, 0.2078, ..., -0.8039, -0.6627, -0.6627],
[-0.1137, -0.2314, -0.0824, ..., -0.7020, -0.5373, -0.4980],
...,
[-0.2784, -0.1373, -0.2000, ..., -0.9529, -0.9529, -0.9451],
[-0.6000, -0.3098, -0.3176, ..., -0.9765, -0.9843, -0.9922],
[-0.7569, -0.4118, -0.3804, ..., -0.9765, -0.9686, -0.9686]]]])}
torch.Size([1, 3, 224, 224])

Processing the Dataset

接下来就要批量将 dataset 里的 image 转换为数值,我们直接利用 ViT 的 ImageProcessor。我们定义这个函数,仅返回训练所需的数据,即 1. 数值化的图片 2.分类标签

1
2
3
4
5
6
def process_example(example):
inputs = processor(example['image'], return_tensors='pt')
inputs['labels'] = example['labels']
return inputs

process_example(ds['train'][0])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
{'pixel_values': tensor([[[[-0.5686, -0.5686, -0.5608,  ..., -0.0275,  0.1843, -0.2471],
[-0.6078, -0.6000, -0.5765, ..., -0.0353, -0.0196, -0.2627],
[-0.6314, -0.6314, -0.6078, ..., -0.2314, -0.3647, -0.2235],
...,
[-0.5373, -0.5529, -0.5843, ..., -0.0824, -0.0431, -0.0902],
[-0.5608, -0.5765, -0.5843, ..., 0.3098, 0.1843, 0.1294],
[-0.5843, -0.5922, -0.6078, ..., 0.2627, 0.1608, 0.2000]],

[[-0.7098, -0.7098, -0.7490, ..., -0.3725, -0.1608, -0.6000],
[-0.7333, -0.7333, -0.7569, ..., -0.3647, -0.3255, -0.5686],
[-0.7490, -0.7490, -0.7725, ..., -0.5373, -0.6549, -0.5373],
...,
[-0.7725, -0.7804, -0.8196, ..., -0.2235, -0.0353, 0.0824],
[-0.7961, -0.8118, -0.8118, ..., 0.1922, 0.3098, 0.3725],
[-0.8196, -0.8196, -0.8275, ..., 0.0824, 0.2784, 0.3961]],

[[-0.9922, -0.9922, -1.0000, ..., -0.5451, -0.3569, -0.7255],
[-0.9922, -0.9922, -1.0000, ..., -0.5529, -0.5216, -0.7176],
[-0.9843, -0.9922, -1.0000, ..., -0.6549, -0.7569, -0.6392],
...,
[-0.8431, -0.8588, -0.8980, ..., -0.5765, -0.5529, -0.5451],
[-0.8588, -0.8902, -0.9059, ..., -0.2000, -0.2392, -0.2627],
[-0.8824, -0.9059, -0.9216, ..., -0.2549, -0.2000, -0.1216]]]]), 'labels': 0}

map 返回的数据结构会将 PyTorch tensor 自动转换为 Python list(tolist()),以保证结果是 JSON 可序列化的。

1
2
3
4
5
6
prepared_ds = ds.map(
process_example,
remove_columns=ds['train'].column_names,
batched=False,
desc="Processing dataset"
)
1
ds['train']
1
2
3
4
Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 1034
})
1
prepared_ds['train']
1
2
3
4
Dataset({
features: ['labels', 'pixel_values'],
num_rows: 1034
})

Training and Evaluation

数据已经处理完毕,现在你可以开始搭建训练流程了。本教程使用的是 Hugging Face 的 Trainer,但在此之前我们需要完成以下几项准备工作:

  • 定义一个 collate 函数:用于将一个 batch 的数据整理成模型可以接受的格式。
  • 定义评估指标:在训练过程中,模型需要根据预测准确率进行评估,因此你需要定义一个 compute_metrics 函数来计算这一指标。
  • 加载预训练模型检查点:你需要加载一个预训练的检查点,并正确配置它以便进行训练。
  • 定义训练配置:包括超参数设置、保存策略、日志输出等。

data collator 里再次将数据转换为 tensor,因为 dataset.map 默认还是会把 tensor 改为 list

1
2
3
4
5
6
7
8
import torch

def collate_fn(batch):
return {
'pixel_values': torch.stack([torch.tensor(x['pixel_values']).squeeze(0) for x in batch]),
'labels': torch.tensor([x['labels'] for x in batch])
}

1
collate_fn([prepared_ds['train'][1], prepared_ds['train'][1]])['pixel_values'].shape
1
torch.Size([2, 3, 224, 224])

我们利用 evaluate 的 accuracy 函数来计算分类准确率

1
2
3
4
5
6
7
import numpy as np
from evaluate import load

metric = load("../DC/evaluate/metrics/accuracy")
def compute_metrics(p):
return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

现在让我们加载预训练模型。在初始化时,我们会传入 num_labels 参数,以便模型构建一个具有正确输出单元数量的分类头。同时,我们还会提供 id2labellabel2id 的映射关系,以便在将模型推送到 Hugging Face Hub 时,能够在界面中显示可读的标签名称。

1
2
3
4
5
6
7
8
9
10
11
12
from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names
print(f'num of labels: {len(labels)}')

model = ViTForImageClassification.from_pretrained(
model_name_or_path,
num_labels=len(labels),
id2label={i: c for i, c in enumerate(labels)},
label2id={c: i for i, c in enumerate(labels)}
)

1
2
3
num of labels: 3
Some weights of ViTForImageClassification were not initialized from the model checkpoint at ../DC/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import TrainingArguments


training_args = TrainingArguments(
output_dir="./vit-base-beans",
per_device_train_batch_size=16,
eval_strategy="steps",
num_train_epochs=2,
fp16=True,
save_steps=100,
eval_steps=10,
logging_steps=1,
learning_rate=2e-4,
save_total_limit=2,
remove_unused_columns=False,
push_to_hub=False,
report_to='wandb', # 使用 wandb 进行训练监控
load_best_model_at_end=True,
)

1
2
3
4
5
6
7
8
9
10
11
12
from transformers import Trainer

trainer = Trainer(
model=model,
args=training_args,
data_collator=collate_fn,
compute_metrics=compute_metrics,
train_dataset=prepared_ds["train"],
eval_dataset=prepared_ds["validation"],
processing_class=processor,
)

1
2
3
4
5
6
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

1
2
3
4
5
6
7
***** train metrics *****
epoch = 2.0
total_flos = 149248978GF
train_loss = 0.201
train_runtime = 0:04:42.28
train_samples_per_second = 7.326
train_steps_per_second = 0.461

image-20250516214044148

1
2
3
4
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

1
2
3
4
5
6
7
***** eval metrics *****
epoch = 2.0
eval_accuracy = 0.9925
eval_loss = 0.0374
eval_runtime = 0:00:09.34
eval_samples_per_second = 14.238
eval_steps_per_second = 1.82

infer

推理很简单了,直接把某个 image 通过 ImageProcessor 处理下图片为 tensor,接着 forward 到 model 即可,得到 logit 后再 argmax 就得到了预测类别。

1
2
ex = ds['test'][0]
ex
1
2
3
{'image_file_path': '/home/albert/.cache/huggingface/datasets/downloads/extracted/807042d188eb9a5d1d9a4179867e5b93eea6ed98d063904065fe40011681df29/test/angular_leaf_spot/angular_leaf_spot_test.0.jpg',
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500>,
'labels': 0}
1
2
3
4
5
6
7
8
9
10
with torch.no_grad():
pixel_values = processor(ex['image'], return_tensors='pt').pixel_values
pixel_values = pixel_values.to(model.device)
outputs = model(pixel_values)
logits = outputs.logits
print(logits.shape)

prediction = logits.argmax(-1)
print("Predicted class index:", prediction.item())
print("Predicted class:", model.config.id2label[prediction.item()])
1
2
3
torch.Size([1, 3])
Predicted class index: 0
Predicted class: angular_leaf_spot
Comments