CLIP - Contrastive Language-Image Pre-training

Peng Xia

Learning Transferable Visual Models From Natural Language Supervision. Paper: https://arxiv.org/abs/2103.00020

OpenAI introduced CLIP, short for Contrastive Language-Image Pre-training. In brief, the model learns the relationship between a full sentence and the image it describes; once trained, given an input text it can retrieve the images most relevant to that text. A key point is that training uses complete sentences rather than single class labels (such as "car" or "dog"). The core intuition is that training on full sentences lets the model learn more information and discover richer patterns linking images and text.

CLIP jointly trains an image encoder and a text encoder to predict the correct pairings of a batch of (image, text) training examples.

Put plainly, we want the image embedding and the text embedding of its paired caption to have high similarity, and low similarity with the text embeddings of non-matching captions. The trained components are the image encoder and the text encoder; more precisely, we use pretrained models as the encoders and attach a projection head after each one to map the source-modality embedding into a unified embedding space. The encoders are pretrained and only fine-tuned, while the projection heads are trained from scratch. This also resolves the mismatch in embedding dimensions: in the code below, ResNet50 produces a 2048-d image feature and DistilBERT a 768-d text feature, and both projection heads map them to the same 256-d space.

(Figure: overview of CLIP's contrastive pre-training, from the paper.)

At test time the learned text encoder synthesizes a zero-shot linear classifier by embedding the names or descriptions of the target dataset’s classes.

For zero-shot classification at inference time, each class name is converted to text and passed through the text encoder, giving one text embedding per class. The image to be classified is passed through the image encoder to obtain an image embedding; we then compute the similarity between the image embedding and each text embedding, apply softmax, and the class with the highest score is the prediction.
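A minimal sketch of that procedure, assuming the CLIPModel and tokenizer defined later in this post (the class names and the prompt template are made up for illustration):

import torch
import torch.nn.functional as F

class_names = ["dog", "cat", "car"]                      # hypothetical label set
prompts = [f"a photo of a {c}" for c in class_names]     # wrap each class name in a sentence

@torch.no_grad()
def zero_shot_classify(model, tokenizer, image_tensor, device):
    # one text embedding per class
    encoded = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    text_emb = model.text_projection(
        model.text_encoder(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
    )
    # one image embedding for the query image (expects a preprocessed (3, H, W) tensor)
    img_emb = model.image_projection(model.image_encoder(image_tensor.unsqueeze(0).to(device)))

    # cosine similarity against every class text, softmax over classes
    sims = F.normalize(img_emb, dim=-1) @ F.normalize(text_emb, dim=-1).T
    probs = sims.softmax(dim=-1)
    return class_names[probs.argmax(dim=-1).item()]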

import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import albumentations as A

import timm
import torch
from torch import nn
import torch.nn.functional as F  # used by the soft-target loss variant below
from transformers import AutoModel, AutoTokenizer, DistilBertConfig

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Loading the image-caption data

Flickr8k dataset: https://www.kaggle.com/datasets/adityajn105/flickr8k

Flickr8k ships the image-caption pairs as a text file (captions.txt); here we read it into a dataframe.

image_path = "./Flicker-8k/Images"
captions_path = "./Flicker-8k/captions.txt"
train_ratio = 0.5
batch_size = 32
num_workers = 4
size = 224


dataframe = pd.read_csv(captions_path)
# print(dataframe[:10].to_markdown())
print('num of data:', len(dataframe))
dataframe = dataframe.sample(frac=1, random_state=42).reset_index(drop=True)
train_size = int(len(dataframe) * train_ratio)
train_dataframe = dataframe[:train_size]
test_dataframe = dataframe[train_size:]
image caption
0 1000268201_693b08cb0e.jpg A child in a pink dress is climbing up a set of stairs in an entry way .
1 1000268201_693b08cb0e.jpg A girl going into a wooden building .
2 1000268201_693b08cb0e.jpg A little girl climbing into a wooden playhouse .
3 1000268201_693b08cb0e.jpg A little girl climbing the stairs to her playhouse .
4 1000268201_693b08cb0e.jpg A little girl in a pink dress going into a wooden cabin .
5 1001773457_577c3a7d70.jpg A black dog and a spotted dog are fighting
6 1001773457_577c3a7d70.jpg A black dog and a tri-colored dog playing with each other on the road .
7 1001773457_577c3a7d70.jpg A black dog and a white dog with brown spots are staring at each other in the street .
8 1001773457_577c3a7d70.jpg Two dogs of different breeds looking at each other on the road .
9 1001773457_577c3a7d70.jpg Two dogs on pavement moving toward each other .

Each image has 5 captions describing its content.
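As a quick sanity check of that claim (a small sketch over the dataframe loaded above):

# count captions per image file; a single value of 5 means every image has exactly 5 captions
captions_per_image = dataframe.groupby("image")["caption"].count()
print(captions_per_image.value_counts())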

tokenizer = AutoTokenizer.from_pretrained("/home/guest/others/data_collection/distilbert-base-uncased")
print(tokenizer("hello")) # {'input_ids': [101, 7592, 102], 'attention_mask': [1, 1, 1]}

The dataset setup is simple. Whereas supervised learning pairs an x with a y, contrastive learning works with positive and negative samples: the image and caption in each row above form a positive pair, and the negatives for that image are the other captions in the same batch.

Strictly speaking, the same image may appear twice in one batch, with captions that are semantically similar but not identical; such a pair is still treated as a negative. Since this rarely happens, i.e. only with small probability is a true positive handled as a negative, training as a whole is not noticeably affected. A rough check of how often it happens is sketched below.
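To get a feel for how rare this is, here is a rough sketch (my own check, not part of the original code) that samples random 32-row batches from the training split and counts how often the same image shows up more than once:

rng = np.random.default_rng(0)
dup_counts = []
for _ in range(1000):
    # roughly simulate one shuffled batch by sampling batch_size distinct rows
    batch_images = rng.choice(train_dataframe["image"].values, size=batch_size, replace=False)
    dup_counts.append(batch_size - len(set(batch_images)))
print("average number of duplicated images per batch:", np.mean(dup_counts))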

def get_transforms(mode="train"):
    # train and eval currently share the same transforms; augmentations could be added to the train branch
    if mode == "train":
        return A.Compose([
            A.Resize(size, size),
            A.Normalize(max_pixel_value=255.0),
        ])
    else:
        return A.Compose([
            A.Resize(size, size),
            A.Normalize(max_pixel_value=255.0),
        ])


class CLIPDataset(torch.utils.data.Dataset):

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=200
        )
        self.transforms = transforms

    # each item is one image and one caption; more precisely, a tensor-ized image and a tokenized caption
    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image = cv2.imread(f"{image_path}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)


def build_loaders(dataframe, tokenizer, mode, batch_size=32, num_workers=4):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader


train_loader = build_loaders(train_dataframe, tokenizer, mode="train", batch_size=batch_size)
valid_loader = build_loaders(test_dataframe, tokenizer, mode="valid", batch_size=batch_size)  # evaluate on the held-out split
# print(train_loader.dataset.__getitem__(0))
# print(next(iter(train_loader)))
a_batch = next(iter(train_loader))
for k, v in a_batch.items():
    print(f"key name: {k}, value data type: {type(v)}, value shape: {v.shape if isinstance(v, torch.Tensor) else len(v)}")
key name: input_ids, value data type: <class 'torch.Tensor'>, value shape: torch.Size([32, 42])
key name: attention_mask, value data type: <class 'torch.Tensor'>, value shape: torch.Size([32, 42])
key name: image, value data type: <class 'torch.Tensor'>, value shape: torch.Size([32, 3, 224, 224])
key name: caption, value data type: <class 'list'>, value shape: 32

Model

  • The image encoder is a ResNet or a ViT; its forward is used as-is and returns the image embedding.
  • The text encoder is DistilBERT (a BERT variant); its forward returns the embedding at the CLS token position.
  • Two projection heads map the image embedding and the text embedding, respectively, into a unified embedding space.
class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector
    """

    def __init__(
        self, model_name='resnet50', pretrained=True, trainable=True, pretrained_cfg_overlay=None
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg", pretrained_cfg_overlay=pretrained_cfg_overlay
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):

    def __init__(self, model_name="distilbert-base-uncased", pretrained=True, trainable=True):
        super().__init__()
        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            self.model = AutoModel.from_config(DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # (batch_size, sequence_length, hidden_size)
        last_hidden_state = output.last_hidden_state
        # use CLS token hidden representation
        # return (batch_size, hidden_size)
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=256,
        dropout=0.1
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


class CLIPModel(nn.Module):

    def __init__(
        self,
        image_encoder,
        text_encoder,
        image_embedding=2048,
        text_embedding=768,
        temperature=1.0,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, images: torch.Tensor, input_ids: torch.Tensor, attention_masks: torch.Tensor):
        # Getting Image and Text Features
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_masks
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # logits[i][j] is the similarity between text i and image j
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        return logits, image_embeddings, text_embeddings


image_encoder = ImageEncoder(
    model_name='resnet50',
    pretrained_cfg_overlay=dict(file='/home/guest/others/xp/clip/timm-resnet50/pytorch_model.bin'),
)

text_encoder = TextEncoder(
    model_name="/home/guest/others/data_collection/distilbert-base-uncased"
)

model = CLIPModel(
    image_encoder=image_encoder,
    text_encoder=text_encoder,
    image_embedding=2048,
    text_embedding=768,
    temperature=1.0,
)
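As a quick sanity check (a sketch with random dummy inputs, not part of the original training code), one forward pass should yield a (batch_size, batch_size) similarity matrix and 256-d embeddings:

with torch.no_grad():
    dummy_images = torch.randn(4, 3, size, size)
    dummy_texts = tokenizer(
        ["a dog", "a cat", "a child on stairs", "two dogs playing"],
        padding=True, return_tensors="pt",
    )
    logits, img_emb, txt_emb = model(
        images=dummy_images,
        input_ids=dummy_texts["input_ids"],
        attention_masks=dummy_texts["attention_mask"],
    )
    print(logits.shape, img_emb.shape, txt_emb.shape)
    # expected: torch.Size([4, 4]) torch.Size([4, 256]) torch.Size([4, 256])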

Training functions

There are two ways to write the loss.

  1. The CLIP paper uses the identity (diagonal) pairing as the label: compute cross-entropy against an arange(n) target along dim 0 and along dim 1, then average the two. This is how the original implements it; the pseudocode from the paper is shown below.

(Figure: Numpy-like pseudocode for the core of CLIP, from the paper.)

Given a batch of N (image, text) pairs, CLIP is trained to predict which of the N × N possible (image, text) pairings across a batch actually occurred. To do this, CLIP learns a multi-modal embedding space by jointly training an image encoder and text encoder to maximize the cosine similarity of the image and text embeddings of the N real pairs in the batch while minimizing the cosine similarity of the embeddings of the N^2 − N incorrect pairings.

batch_size = image_embeddings.shape[0]
labels = torch.arange(batch_size).to(device)
texts_loss = self.criterion(logits, labels)
images_loss = self.criterion(logits.T, labels)

loss = (images_loss + texts_loss) / 2.0
  2. The other approach uses the image-image and text-text similarities as soft targets: the diagonal still gets high weight, and some off-diagonal entries also receive weight according to how similar the corresponding images and texts actually are. I feel this can compensate for the case where a true positive ends up being treated as a negative.
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax(
    (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
)
texts_loss = nn.CrossEntropyLoss()(logits, targets)
images_loss = nn.CrossEntropyLoss()(logits.T, targets.T)
loss = (images_loss + texts_loss) / 2.0
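Note that passing probability (soft) targets to nn.CrossEntropyLoss requires PyTorch 1.10 or newer, and with the default reduction the result is a scalar rather than a per-sample vector.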
def train_batch(model, optimizer, batch, device):

    optimizer.zero_grad()
    criterion = nn.CrossEntropyLoss()

    images = batch["image"].to(device)
    input_ids = batch["input_ids"].to(device)
    attention_masks = batch["attention_mask"].to(device)

    # image_embeddings and text_embeddings have the same batch size
    # logits is a matrix of shape (batch_size, batch_size)
    # logits[i][j] is the similarity score between text i and image j
    logits, image_embeddings, text_embeddings = model(
        images=images,
        input_ids=input_ids,
        attention_masks=attention_masks
    )

    batch_size = image_embeddings.shape[0]
    labels = torch.arange(batch_size).to(device)
    texts_loss = criterion(logits, labels)
    images_loss = criterion(logits.T, labels)
    loss = (images_loss + texts_loss) / 2.0

    loss.backward()
    optimizer.step()

    return loss.item()


def train_epoch(model, optimizer, dataloader, device):

    model.train()

    losses = []

    for batch in tqdm(dataloader):
        batch_loss = train_batch(model, optimizer, batch, device)
        losses.append(batch_loss)

    ave_loss = np.mean(losses)
    return ave_loss


def test_batch(model, batch, device):

    images = batch["image"].to(device)
    input_ids = batch["input_ids"].to(device)
    attention_masks = batch["attention_mask"].to(device)

    logits, image_embeddings, text_embeddings = model(
        images=images,
        input_ids=input_ids,
        attention_masks=attention_masks
    )

    # in-batch retrieval accuracy: for each text, is the matching image ranked first?
    gt = torch.arange(logits.size(0), device=logits.device)
    pred = torch.argmax(logits, dim=1)
    correct = (pred == gt).sum().item()

    return correct / logits.size(0)


def test_epoch(model, dataloader, device):

    model.eval()

    accs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_acc = test_batch(model, batch, device)
            accs.append(batch_acc)

    ave_acc = np.mean(accs)
    return ave_acc

Training loop

Since the image encoder and the text encoder are already pretrained while the projection heads are trained from scratch, it is best to give the encoders a smaller learning rate and the projection heads a normal one.

n_epochs = 20
device = torch.device("cuda:1")

model = model.to(device)


# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
optimizer = torch.optim.AdamW([
    {"params": model.image_encoder.parameters(), "lr": 1e-5},
    {"params": model.text_encoder.parameters(), "lr": 1e-5},
    {"params": model.image_projection.parameters(), "lr": 1e-3},
    {"params": model.text_projection.parameters(), "lr": 1e-3},
])
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=5, gamma=0.5,
)

for epoch in range(n_epochs):

    print(f"Epoch: {epoch + 1}")
    print(f'Learning Rate: {lr_scheduler.get_last_lr()[0]}')

    train_loss = train_epoch(model, optimizer, train_loader, device)
    print(f"Train Loss: {train_loss:.4f}")

    test_acc = test_epoch(model, valid_loader, device)
    print(f"Test Accuracy: {test_acc:.4f}")

    lr_scheduler.step()
    # break

Epoch: 1
Learning Rate: 1e-05
100%|██████████| 633/633 [02:00<00:00, 5.27it/s]
Train Loss: 2.2545
100%|██████████| 633/633 [00:50<00:00, 12.64it/s]
Test Accuracy: 0.7001
Epoch: 2
Learning Rate: 1e-05
100%|██████████| 633/633 [02:07<00:00, 4.98it/s]
Train Loss: 0.7770
100%|██████████| 633/633 [01:08<00:00, 9.23it/s]
Test Accuracy: 0.8489
Epoch: 3
Learning Rate: 1e-05
100%|██████████| 633/633 [02:18<00:00, 4.58it/s]
Train Loss: 0.4705
100%|██████████| 633/633 [01:04<00:00, 9.77it/s]
Test Accuracy: 0.9080
Epoch: 4
Learning Rate: 1e-05
100%|██████████| 633/633 [01:53<00:00, 5.56it/s]
Train Loss: 0.3176
100%|██████████| 633/633 [00:50<00:00, 12.51it/s]
Test Accuracy: 0.9228
Epoch: 5
Learning Rate: 1e-05
100%|██████████| 633/633 [01:57<00:00, 5.40it/s]
Train Loss: 0.2462
100%|██████████| 633/633 [00:51<00:00, 12.20it/s]
Test Accuracy: 0.9343
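A note on the metric: "Test Accuracy" here is in-batch text-to-image retrieval accuracy, i.e. for each caption, whether its own image receives the highest similarity among the images in the same batch; with a batch size of 32, chance level is roughly 3%.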