Qwen 2.5 VL Fine-tuning Processing

Peng Xia

There doesn't seem to be a unified training recipe for LVLMs in transformers, most likely because every model handles its internal modules differently. LLMs all share a similar GPT-style architecture, and their forward pass deals with a single modality only, so the processing is fairly uniform; LVLMs are not like that. So in this blog I follow the processing in qwen-vl-finetune from Qwen 2.5 VL and reproduce its pipeline, mainly the data processing flow, followed by a brief walk-through of the model forward.

Processing a single sample

This is the format of a single sample. The conversation format officially supported is the ShareGPT format, whose role and message field names differ from the commonly used OpenAI format, even though the actual code still reasons about roles in the OpenAI style.

{
    "image": ["10149.png", "COCO_train2014_000000580957.jpg"],
    "conversations": [
        {
            "from": "human",
            "value": "<image>\nIn which year the value was 51?\n<image>"
        },
        {
            "from": "gpt",
            "value": "2014"
        }
    ],
    "data_path": "demo/images"
}

For convenience, I will use the OpenAI-style conversation format directly from here on.
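The repo maps ShareGPT roles internally; as a minimal sketch (my own, not from the repo), converting ShareGPT-style turns into OpenAI-style ones could look like this:

ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}

def sharegpt_to_openai(conversations):
    # "from"/"value" (ShareGPT) -> "role"/"content" (OpenAI)
    return [{"role": ROLE_MAP[turn["from"]], "content": turn["value"]} for turn in conversations]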

import os, torch, copy, json
from PIL import Image
from typing import List, Dict
from transformers import AutoProcessor, AutoTokenizer
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor


model_id = '../../../DC/qwen2.5vl-3b-ins'
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
image_processor = processor.image_processor
tokenizer = processor.tokenizer
source = {
    "image": ["10149.png", "COCO_train2014_000000580957.jpg"],
    "conversations": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "<image>\nIn which year the value was 51?\n<image>"},
        {"role": "assistant", "content": "2014"}
    ],
    "data_path": "demo/images"
}

The official processing logic is wrapped inside the LazySupervisedDataset in qwen2.5-vl/qwen-vl-finetune/qwenvl/data/data_qwen.
A lazy dataset only reads a single sample or batch when it is needed, which suits large-scale data (images, videos, long text); the actual processing usually lives inside __getitem__. A normal dataset loads everything into memory or finishes preprocessing at init time, which does not scale to very large datasets.

So everything that follows operates on a single sample, i.e. the logic that would go into __getitem__.
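A minimal sketch (mine, not the official class) of the lazy-vs-eager difference; load_and_preprocess is a hypothetical helper standing in for any decode/preprocess step:

from torch.utils.data import Dataset

class EagerDataset(Dataset):
    def __init__(self, paths):
        # decodes everything up front: fast __getitem__, but memory-hungry for large image corpora
        self.samples = [load_and_preprocess(p) for p in paths]   # hypothetical helper
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]

class LazyDataset(Dataset):
    def __init__(self, paths):
        self.paths = paths                                       # only the lightweight index stays in memory
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        return load_and_preprocess(self.paths[idx])              # decode on demand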

Image processing

This step preprocesses the images:

  • resize the image
  • normalize the pixels (this is the ViT input)
  • compute the image grid dimensions (one of the parameters for the text processing step, which must reserve slots for the image tokens)
def process_single_image(processor, image_file:str):
    # read the image as a PIL object
    image = Image.open(image_file).convert("RGB")
    # preprocess the single image into tensors
    visual_processed = processor.preprocess(image, return_tensors="pt")
    image_tensor = visual_processed["pixel_values"]
    if isinstance(image_tensor, List):
        image_tensor = image_tensor[0]
    grid_thw = visual_processed["image_grid_thw"][0]
    return image_tensor, grid_thw

if "image" in source:
    image_folder = source["data_path"]
    image_file = source["image"]
    if not isinstance(image_file, List):
        image_file = [image_file]

    image_file = [
        os.path.join(image_folder, file) for file in image_file
    ]
    results = [
        process_single_image(image_processor, file)
        for file in image_file
    ]
    # grid_thw stands for time, height, width
    image_list, grid_thw_list = zip(*results)

    # product of grid_thw divided by merge_size squared = number of tokens the image occupies after projection;
    # Qwen 2.5 VL merges neighbouring patches into one token, hence the division by merge_size ** 2
    grid_thw_merged = copy.deepcopy(grid_thw_list)
    grid_thw_merged = [
        grid_thw.prod() // image_processor.merge_size ** 2
        for grid_thw in grid_thw_merged
    ]

print("image_list:", *[item.shape for item in image_list])
print("grid_thw_list:", *[item.tolist() for item in grid_thw_list])
print("grid_thw_merged:", grid_thw_merged)

The preprocessing output is shown below; some of these values are needed further down.

image_list: torch.Size([440, 1176]) torch.Size([1380, 1176])
grid_thw_list: [1, 20, 22] [1, 30, 46]
grid_thw_merged: [tensor(110), tensor(345)]

image_list holds the data already split into patches (the ViT consumes patches), and the first dimension is the number of patches each image was converted into. Flattening to patches loses the aspect-ratio information, which is exactly what grid_thw preserves. grid_thw_merged is the number of image tokens per image, used by the text processing step to reserve the corresponding pad tokens.
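A quick sanity check of the numbers above for the first image; the constants are my reading of the Qwen2.5-VL image processor defaults (patch_size=14, temporal_patch_size=2, merge_size=2), so verify against your processor config:

t, h, w = 1, 20, 22                      # grid_thw of the first image
assert t * h * w == 440                  # number of ViT patches == first dim of its pixel_values
assert 3 * 2 * 14 * 14 == 1176           # channels * temporal_patch_size * patch_size**2 == second dim
assert (t * h * w) // image_processor.merge_size ** 2 == 110   # image tokens after the 2x2 merger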

Text processing

# re-assign the chat template on a copy of the tokenizer; see https://huggingface.co/blog/chat-templates if unfamiliar
tokenizer_alter = copy.deepcopy(tokenizer)
# a simplified template: wrap each message as <|im_start|>role\ncontent<|im_end|>\n
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer_alter.chat_template = chat_template

This is exactly the chat template format of the Qwen language model (pretty-printed below):

{% for message in messages %}
    {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
{% endfor %}
{% if add_generation_prompt %}
    {{ '<|im_start|>assistant\n' }}
{% endif %}

The original Qwen-VL chat template is considerably more complex:

{# initialize the image and video counters #}
{% set image_count = namespace(value=0) %}
{% set video_count = namespace(value=0) %}

{# iterate over all messages #}
{% for message in messages %}

    {# insert the default system prompt when the first message is not a system message #}
    {% if loop.first and message['role'] != 'system' %}
        <|im_start|>system
        You are a helpful assistant.
        <|im_end|>
    {% endif %}

    <|im_start|>{{ message['role'] }}

    {# if the message content is a plain string #}
    {% if message['content'] is string %}
        {{ message['content'] }}
        <|im_end|>

    {# if the message content is a list of content parts #}
    {% else %}
        {% for content in message['content'] %}

            {# image content #}
            {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
                {% set image_count.value = image_count.value + 1 %}
                {% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}
                <|vision_start|><|image_pad|><|vision_end|>

            {# video content #}
            {% elif content['type'] == 'video' or 'video' in content %}
                {% set video_count.value = video_count.value + 1 %}
                {% if add_vision_id %}Video {{ video_count.value }}: {% endif %}
                <|vision_start|><|video_pad|><|vision_end|>

            {# text content #}
            {% elif 'text' in content %}
                {{ content['text'] }}
            {% endif %}

        {% endfor %}
        <|im_end|>

    {% endif %}

{% endfor %}

{# optionally append the assistant generation prompt at the end #}
{% if add_generation_prompt %}
    <|im_start|>assistant
{% endif %}

Next, process the conversation. First, make sure it is well-formed:

  • user messages and assistant messages alternate, i.e. they come in pairs
  • a user message must come first, unless a system message precedes it.
default_system_message = "You are a helpful assistant."

chat_sources = copy.deepcopy(source["conversations"])

r"""
Handling in the original source: it only guarantees that the first message is a human message;
a system message cannot be configured there. In practice, a leading system message followed by
a human message also works:
try:
    if roles[source[0]["from"]] != roles["human"]:
        source = source[1:]
except:
    print(chat_sources)
"""

# make sure the first message is a user message (optionally preceded by a system message)
while not (
    (chat_sources[0]['role'] == 'user') or
    (len(chat_sources) >= 2 and chat_sources[0]['role'] == 'system' and chat_sources[1]['role'] == 'user')
):
    chat_sources = chat_sources[1:]

# make sure user and assistant messages alternate, i.e. they are paired
if chat_sources[0]['role'] == 'system':
    assert len(chat_sources) % 2 == 1, "user messages and assistant messages must be paired."
    for i in range(1, len(chat_sources), 2):
        assert (chat_sources[i]['role'] == 'user' and chat_sources[i + 1]['role'] == 'assistant'), "user messages and assistant messages must be paired."
else:
    assert len(chat_sources) % 2 == 0, "user messages and assistant messages must be paired."
    for i in range(0, len(chat_sources), 2):
        assert (chat_sources[i]['role'] == 'user' and chat_sources[i + 1]['role'] == 'assistant'), "user messages and assistant messages must be paired."
    chat_sources = [{"role": "system", "content": default_system_message}] + chat_sources  # prepend the default system message

print(json.dumps(chat_sources, indent=2, ensure_ascii=False))
[
  {
    "role": "system",
    "content": "You are a helpful assistant."
  },
  {
    "role": "user",
    "content": "<image>\nIn which year the value was 51?\n<image>"
  },
  {
    "role": "assistant",
    "content": "2014"
  }
]

This is now a well-formed message list.

Next, process the messages one by one and build the labels according to each message's role. Note that only the assistant content keeps its original token ids as labels; everything else is set to -100.

Compared with pure-text processing, the main difference is that token slots must be reserved for the vision content. The image step above already computed the (merged) patch count of each image, and those counts are used here to insert the pad tokens.

# This variable tracks the index of the image currently being processed; handling an image simply
# means inserting the right number of '<|image_pad|>' tokens for its size.
visual_replicate_index_image = 0
IGNORE_INDEX = -100
input_id, target = [], []

for conv in chat_sources:

    role = conv['role']
    content = conv['content']

    # only user content needs special handling, because it may contain image/video placeholders
    if role == 'user':

        if '<image>' in content:
            parts = content.split('<image>')
            num_parts = len(parts)
            new_parts = []
            for i in range(num_parts):
                # add the text piece split off by <image>
                new_parts.append(parts[i])
                # add the image tokens, except after the last piece
                if i != num_parts - 1:
                    image_tokens = (
                        "<|vision_start|>"
                        + "<|image_pad|>" * grid_thw_merged[visual_replicate_index_image]
                        + "<|vision_end|>"
                    )
                    visual_replicate_index_image += 1
                    new_parts.append(image_tokens)

            content = "".join(new_parts)

        elif '<video>' in content:
            # video tasks are out of scope here
            pass
        else:
            pass

    text_conv = [{"role": role, "content": content}]
    encode_id = tokenizer_alter.apply_chat_template(text_conv)  # must use tokenizer_alter; the default tokenizer would inject a default system message
    input_id += encode_id
    if role in ["user", "system"]:
        # no loss on user and system messages
        target += [IGNORE_INDEX] * len(encode_id)
    else:
        # assistant messages contribute to the loss, except the leading special tokens
        target_mask = encode_id.copy()
        target_mask[:3] = [IGNORE_INDEX] * 3  # ignore the leading '<|im_start|>assistant\n'
        target += target_mask

assert len(input_id) == len(target), \
    f"input_id length ({len(input_id)}) != target length ({len(target)})"
assert visual_replicate_index_image == len(grid_thw_merged), \
    f"visual_replicate_index_image ({visual_replicate_index_image}) != len(grid_thw_merged) ({len(grid_thw_merged)})"

input_ids = torch.tensor([input_id], dtype=torch.long)
labels = torch.tensor([target], dtype=torch.long)

print(
    tokenizer.decode(input_ids[0]),
    tokenizer.decode([i for i in labels[0] if i != IGNORE_INDEX]),
    sep='\n'+'-'*10+'\n'
)

Below are the decoded input_ids:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|>...(110 <|image_pad|> tokens in total)<|vision_end|>
In which year the value was 51?
<|vision_start|><|image_pad|>...(345 <|image_pad|> tokens in total)<|vision_end|><|im_end|>
<|im_start|>assistant
2014<|im_end|>

----------
2014<|im_end|>
----------

Below the dashed line is the non-(-100) label content: only the assistant message and its end token, '2014<|im_end|>'.

Position encoding

For Qwen 2.5 VL's 3D RoPE I recommend reading a dedicated blog post; here I just use the RoPE index code from the source. Compared with ordinary RoPE, it adds extra axes for images and videos, so its first dimension is 3: within an image the text (temporal) index stays fixed, while the height/width indices advance over the image grid as usual.
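To make "the first dimension is 3" concrete, here is a hand-written schematic of what the index roughly looks like for three text tokens, a 1×2×2 (merged) image grid, then two more text tokens. The values reflect my reading of get_rope_index_25, so treat it as an illustration and verify against the source:

import torch

# sequence:   t0 t1 t2 | image (1x2x2 merged grid) | t3 t4
schematic = torch.tensor([
    [0, 1, 2, 3, 3, 3, 3, 5, 6],   # temporal axis: frozen at the image's start offset inside the image
    [0, 1, 2, 3, 3, 4, 4, 5, 6],   # height axis: walks down the merged grid rows
    [0, 1, 2, 3, 4, 3, 4, 5, 6],   # width axis: walks across the merged grid columns
])
# text tokens share the same index on all three axes; after the image, text resumes from max(index) + 1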

The rope2d module below comes from Qwen2.5-VL/qwen-vl-finetune/qwenvl/data/rope2d.py.

from rope2d import get_rope_index_25


position_ids, _ = get_rope_index_25(
    image_processor.merge_size,
    input_ids,
    image_grid_thw=torch.stack(grid_thw_list, dim=0) if grid_thw_list else None,
    video_grid_thw=None,
)

The data_dict below contains the following entries, which serve as the inputs to model.forward:

  • input_ids: torch.Size([1, 495])
  • labels: torch.Size([1, 495])
  • position_ids: torch.Size([3, 1, 495])
  • attention_mask: torch.Size([1, 495])
  • pixel_values: torch.Size([1820, 1176])
  • image_grid_thw: torch.Size([2, 3])
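A small consistency check (my own) tying these shapes back to the image step: 440 + 1380 = 1820 patches feed the ViT, and 110 + 345 = 455 of the 495 input ids are <|image_pad|> placeholders.

image_pad_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
assert sum(t.shape[0] for t in image_list) == 440 + 1380 == 1820        # patches fed to the ViT
assert int((input_ids == image_pad_id).sum()) == 110 + 345 == 455       # <|image_pad|> slots inside the 495 ids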
data_dict = {
    "input_ids": input_ids,
    "labels": labels,
    "position_ids": position_ids,
    "attention_mask": torch.ones_like(input_ids, dtype=torch.int),
}

if "image" in source:
    # image_list holds the patchified images, e.g. [torch.Size([440, 1176]), torch.Size([1380, 1176])];
    # since everything is patch-level data, concatenate them into one big tensor
    data_dict["pixel_values"] = torch.cat(image_list, dim=0)
    # each image's grid_thw is a [time, height, width] tensor of shape [3]; stack them into a [num_images, 3] tensor
    data_dict["image_grid_thw"] = torch.cat(
        [thw.unsqueeze(0) for thw in grid_thw_list], dim=0
    )

for k, v in data_dict.items():
    print(f"{k}: {v.shape if isinstance(v, torch.Tensor) else v}")

VL model forward pass

Qwen2_5_VLForConditionalGeneration consists of two parts:

  • backbone: Qwen2_5_VLModel
    • visual: Qwen2_5_VisionTransformerPretrainedModel
    • language_model: Qwen2_5_VLTextModel
  • head: nn.Linear

The forward flow of Qwen2_5_VLModel is:

  1. Text token embedding
# from Qwen2_5_VLTextModel
def get_input_embeddings(self):
    return self.embed_tokens

# from Qwen2_5_VLModel
def get_input_embeddings(self):
    return self.language_model.get_input_embeddings()

inputs_embeds = self.get_input_embeddings()(input_ids)
  2. Vision token filling
# from Qwen2_5_VisionTransformerPretrainedModel
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
    # hidden_states = pixel_values
    hidden_states = self.patch_embed(hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    window_index, cu_window_seqlens = self.get_window_index(grid_thw)
    cu_window_seqlens = torch.tensor(
        cu_window_seqlens,
        device=hidden_states.device,
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

    seq_len, _ = hidden_states.size()
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    rotary_pos_emb = rotary_pos_emb[window_index, :, :]
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    position_embeddings = (emb.cos(), emb.sin())

    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0,
        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852 for more information
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    for layer_num, blk in enumerate(self.blocks):
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
        else:
            cu_seqlens_now = cu_window_seqlens
        hidden_states = blk(
            hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
        )

    # merge the hidden states of each 2*2 patch group
    hidden_states = self.merger(hidden_states)
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]

    return hidden_states

# from Qwen2_5_VLModel
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
    # unify the dtype
    pixel_values = pixel_values.type(self.visual.dtype)
    # get the patch hidden states and merge them into image tokens
    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
    # compute the number of tokens of each image from image_grid_thw
    split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
    # split the image tokens into per-image chunks
    image_embeds = torch.split(image_embeds, split_sizes)
    return image_embeds

# get the image tokens (split per image)
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
# concatenate the per-image chunks back into one sequence; they will be scattered into the text embedding in order anyway
image_embeds = torch.cat(image_embeds, dim=0)
# self.config.image_token_id is '<|image_pad|>', i.e. count the image padding tokens in the text
n_image_tokens = (input_ids == self.config.image_token_id).sum()
# number of image token embeddings actually produced
n_image_features = image_embeds.shape[0]
# check that the produced image tokens match the padding tokens reserved in the text
if n_image_tokens != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )
  3. Merging the token embeddings
# mask of the image padding token positions
mask = input_ids == self.config.image_token_id
# expand the mask to the same shape as the text token embeddings
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
# scatter image_embeds into the token embeddings at the masked positions, so the vision and text token embeddings merge into one sequence
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
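A toy example (my own) of how masked_scatter places the image embeddings into the text embedding sequence: positions where the mask is True are filled with the image features, in order.

toy_embeds = torch.zeros(1, 5, 4)                                  # [batch, seq, hidden]
toy_mask = torch.tensor([[False, True, True, False, False]])       # tokens 1 and 2 are <|image_pad|>
toy_mask = toy_mask.unsqueeze(-1).expand_as(toy_embeds)
toy_image = torch.arange(2 * 4, dtype=torch.float).reshape(2, 4)   # 2 "image tokens"
merged = toy_embeds.masked_scatter(toy_mask, toy_image)
print(merged[0, 1], merged[0, 2])                                  # rows 1 and 2 now hold the image features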
  4. Language model forward
# from here on the forward is almost identical to a plain LLM
outputs = self.language_model(
    input_ids=None,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=True,
    cache_position=cache_position,
    **kwargs,
)

output = Qwen2_5_VLModelOutputWithPast(
    last_hidden_state=outputs.last_hidden_state,
    past_key_values=outputs.past_key_values,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
    rope_deltas=self.rope_deltas,
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype="auto",
)
output = model(**data_dict)
for k, v in output.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")
    else:
        print(f"{k}: {v}")
loss: torch.Size([])
logits: torch.Size([1, 495, 151936])
past_key_values: <transformers.cache_utils.DynamicCache object at 0x7f668976e1d0>
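As a cross-check (my own, not from the original script), the scalar loss is the usual causal-LM cross-entropy with labels shifted by one position and -100 entries ignored; recomputing it from the returned logits should roughly match output.loss:

import torch.nn.functional as F

shift_logits = output.logits[:, :-1, :].float()   # position t predicts token t+1
shift_labels = data_dict["labels"][:, 1:]          # -100 everywhere except the assistant answer
manual_loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
    ignore_index=-100,
)
print(manual_loss)   # should be close to output.loss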

Lazy Dataset (putting the above together)

Dataset preprocessing

It supports merging multiple datasets and sampling from each of them.

import re, random, os, torch, copy, json
from PIL import Image
from typing import List
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoProcessor, PreTrainedTokenizer
from rope2d import get_rope_index_25
from dataclasses import dataclass
from typing import Dict, Sequence, List


demo = {
    "annotation_path": "./demo/single_images.json",
    "data_path": "./demo/images",
}

data_dict = {
    "demo": demo,
}

def parse_sampling_rate(dataset_name):
    """
    Parse the sampling rate encoded in a dataset name.

    The function extracts the number following a percent sign (%) at the end of the dataset
    name and converts it into a sampling rate (a fraction). If the dataset name carries no
    sampling rate, 1.0 is returned.
    """
    match = re.search(r"%(\d+)$", dataset_name)
    if match:
        return int(match.group(1)) / 100.0
    return 1.0

def get_dataset_config_list(dataset_names):
    """
    Build the list of dataset config dicts for the given dataset names.

    Example:
    dataset_names = ["demo"]
    returns
    [{'annotation_path': './demo/single_images.json', 'data_path': './demo/images', 'sampling_rate': 1.0}]
    """
    config_list = []
    for dataset_name in dataset_names:
        sampling_rate = parse_sampling_rate(dataset_name)
        dataset_name = re.sub(r"%(\d+)$", "", dataset_name)
        if dataset_name in data_dict.keys():
            config = data_dict[dataset_name].copy()
            config["sampling_rate"] = sampling_rate
            config_list.append(config)
        else:
            raise ValueError(f"do not find {dataset_name}")
    return config_list

def read_jsonl(path):
    with open(path, "r") as f:
        return [json.loads(line) for line in f]

def load_datasets(dataset_config_list):
    list_data_dict = []
    for config in dataset_config_list:

        annotation_path = config['annotation_path']
        data_path = config['data_path']
        sampling_rate = config.get('sampling_rate', 1.0)
        file_format = annotation_path.split(".")[-1]

        if file_format == "jsonl":
            annotations = read_jsonl(annotation_path)
        else:
            annotations = json.load(open(annotation_path, "r"))
        if sampling_rate < 1.0:
            annotations = random.sample(
                annotations, int(len(annotations) * sampling_rate)
            )
            print(f"sampling {len(annotations)} examples from dataset {config}")
        else:
            print(f"full loading dataset: {config}")
        for ann in annotations:
            ann["data_path"] = data_path
        list_data_dict += annotations

    print(f"Total training samples: {len(list_data_dict)}")
    random.shuffle(list_data_dict)
    return list_data_dict

dataset_names = ["demo", "demo%80", "demo%50"]
configs = get_dataset_config_list(dataset_names)
print(configs)
[{'annotation_path': './demo/single_images.json', 'data_path': './demo/images', 'sampling_rate': 1.0}, {'annotation_path': './demo/single_images.json', 'data_path': './demo/images', 'sampling_rate': 0.8}, {'annotation_path': './demo/single_images.json', 'data_path': './demo/images', 'sampling_rate': 0.5}]

Dataset and data collator definitions

class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, dataset_use, image_processor, tokenizer, chat_template, default_system_message):
        super(LazySupervisedDataset, self).__init__()
        dataset_names = dataset_use.split(",")
        dataset_config_list = get_dataset_config_list(dataset_names)
        list_data_dict = load_datasets(dataset_config_list)

        self.list_data_dict = list_data_dict
        self.image_processor = image_processor
        self.tokenizer = copy.deepcopy(tokenizer)
        self.tokenizer.chat_template = chat_template
        self.default_system_message = default_system_message
        self.get_rope_index = get_rope_index_25

    def __len__(self):
        return len(self.list_data_dict)

    # wrapped __getitem__ with retries
    def __getitem__(self, idx):
        num_base_retries = 3
        # first try to fetch the requested index, retrying on failure
        for attempt_idx in range(num_base_retries):
            try:
                sample = self._get_item(idx)
                return sample
            except Exception as e:
                # sleep 1s in case it is a cloud disk issue
                print(f"[Try #{attempt_idx}] Failed to fetch sample {idx}. Exception:", e)
        # if the requested index keeps failing, try the next sample
        for attempt_idx in range(num_base_retries):
            try:
                next_index = (idx + 1) % len(self.list_data_dict)
                # sample_idx = random.choice(range(len(self)))
                sample = self._get_item(next_index)
                return sample
            except Exception as e:
                print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e,)
        # finally try the original sample one more time, otherwise re-raise the exception
        try:
            sample = self._get_item(idx)
            return sample
        except Exception as e:
            raise e

    def process_single_image(self, image_file:str):
        # read the image as a PIL object
        image = Image.open(image_file).convert("RGB")
        # preprocess the single image into tensors
        visual_processed = self.image_processor.preprocess(image, return_tensors="pt")
        image_tensor = visual_processed["pixel_values"]
        if isinstance(image_tensor, List):
            image_tensor = image_tensor[0]
        grid_thw = visual_processed["image_grid_thw"][0]
        return image_tensor, grid_thw

    def preprocess_qwen_2_visual(self, chat_sources, image_grid_thw_merged:List = [], video_grid_thw_merged:List = []):
        # make sure the first message is a user message (optionally preceded by a system message)
        while not (
            (chat_sources[0]['role'] == 'user') or
            (len(chat_sources) >= 2 and chat_sources[0]['role'] == 'system' and chat_sources[1]['role'] == 'user')
        ):
            chat_sources = chat_sources[1:]


        # make sure user and assistant messages alternate, i.e. they are paired
        if chat_sources[0]['role'] == 'system':
            assert len(chat_sources) % 2 == 1, "user messages and assistant messages must be paired."
            for i in range(1, len(chat_sources), 2):
                assert (chat_sources[i]['role'] == 'user' and chat_sources[i + 1]['role'] == 'assistant'), "user messages and assistant messages must be paired."
        else:
            assert len(chat_sources) % 2 == 0, "user messages and assistant messages must be paired."
            for i in range(0, len(chat_sources), 2):
                assert (chat_sources[i]['role'] == 'user' and chat_sources[i + 1]['role'] == 'assistant'), "user messages and assistant messages must be paired."
            chat_sources = [{"role": "system", "content": self.default_system_message}] + chat_sources  # prepend the default system message

        # index of the image currently being processed; handling an image means inserting the right number of '<|image_pad|>' tokens for its size
        visual_replicate_index_image = 0
        IGNORE_INDEX = -100
        input_id, target = [], []

        for conv in chat_sources:

            role = conv['role']
            content = conv['content']

            # only user content needs special handling, because it may contain image/video placeholders
            if role == 'user':

                if '<image>' in content:
                    parts = content.split('<image>')
                    num_parts = len(parts)
                    new_parts = []
                    for i in range(num_parts):
                        # add the text piece split off by <image>
                        new_parts.append(parts[i])
                        # add the image tokens, except after the last piece
                        if i != num_parts - 1:
                            image_tokens = (
                                "<|vision_start|>"
                                + "<|image_pad|>" * image_grid_thw_merged[visual_replicate_index_image]
                                + "<|vision_end|>"
                            )
                            visual_replicate_index_image += 1
                            new_parts.append(image_tokens)

                    content = "".join(new_parts)

                elif '<video>' in content:
                    # video tasks are out of scope here
                    pass
                else:
                    pass

            text_conv = [{"role": role, "content": content}]
            encode_id = self.tokenizer.apply_chat_template(text_conv)
            input_id += encode_id
            if role in ["user", "system"]:
                # no loss on user and system messages
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                # assistant messages contribute to the loss, except the leading special tokens
                target_mask = encode_id.copy()
                target_mask[:3] = [IGNORE_INDEX] * 3  # ignore the leading '<|im_start|>assistant\n'
                target += target_mask

        assert len(input_id) == len(target), \
            f"input_id length ({len(input_id)}) != target length ({len(target)})"
        assert visual_replicate_index_image == len(image_grid_thw_merged), \
            f"visual_replicate_index_image ({visual_replicate_index_image}) != len(image_grid_thw_merged) ({len(image_grid_thw_merged)})"

        input_ids = torch.tensor([input_id], dtype=torch.long)
        labels = torch.tensor([target], dtype=torch.long)

        return dict(
            input_ids=input_ids,
            labels=labels,
        )

    def _get_item(self, idx):

        source = self.list_data_dict[idx]

        image_grid_thw_merged = []
        image_grid_thw_list = []
        video_grid_thw_merged = []
        video_grid_thw_list = []

        if "image" in source:
            image_folder = source["data_path"]
            image_file = source["image"]
            if not isinstance(image_file, List):
                image_file = [image_file]

            image_file = [
                os.path.join(image_folder, file) for file in image_file
            ]
            results = [
                self.process_single_image(file)
                for file in image_file
            ]
            # grid_thw stands for time, height, width
            image_list, image_grid_thw_list = zip(*results)

            # product of grid_thw divided by merge_size squared = number of tokens the image occupies after projection;
            # Qwen 2.5 VL merges neighbouring patches into one token, hence the division by merge_size ** 2
            image_grid_thw_merged = copy.deepcopy(image_grid_thw_list)
            image_grid_thw_merged = [
                grid_thw.prod() // self.image_processor.merge_size ** 2
                for grid_thw in image_grid_thw_merged
            ]

        if "video" in source:
            raise NotImplementedError("video is not supported yet")

        chat_sources = copy.deepcopy(source["conversations"])


        if "image" not in source and "video" not in source:

            data_dict = self.preprocess_qwen_2_visual(
                chat_sources, image_grid_thw_merged=[]
            )
            position_ids = (
                torch.arange(0, data_dict["input_ids"].size(1))
                .view(1, -1)
                .unsqueeze(0)
                .expand(3, -1, -1)
            )
        else:

            data_dict = self.preprocess_qwen_2_visual(
                chat_sources,
                image_grid_thw_merged=image_grid_thw_merged if "image" in source else [],
                video_grid_thw_merged=video_grid_thw_merged if "video" in source else []
            )

            position_ids, _ = self.get_rope_index(
                self.image_processor.merge_size,
                data_dict["input_ids"],
                image_grid_thw=torch.stack(image_grid_thw_list, dim=0) if image_grid_thw_list else None,
                video_grid_thw=None,
            )

        data_dict["position_ids"] = position_ids
        data_dict["attention_mask"] = torch.ones_like(data_dict["input_ids"], dtype=torch.int)
        if "image" in source:
            # image_list holds the patchified images, e.g. [torch.Size([440, 1176]), torch.Size([1380, 1176])];
            # since everything is patch-level data, concatenate them into one big tensor
            data_dict["pixel_values"] = torch.cat(image_list, dim=0)
            # each image's grid_thw is a [time, height, width] tensor; stack them into a [num_images, 3] tensor
            data_dict["image_grid_thw"] = torch.cat(
                [thw.unsqueeze(0) for thw in image_grid_thw_list], dim=0
            )
        elif "video" in source:
            raise NotImplementedError("video is not supported yet")

        return data_dict

def pad_and_cat(tensor_list):
    max_length = max(tensor.shape[2] for tensor in tensor_list)

    padded_tensors = []
    for tensor in tensor_list:
        pad_length = max_length - tensor.shape[2]
        # right-pad the tensor with pad_length ones
        padded_tensor = torch.nn.functional.pad(tensor, (0, pad_length), "constant", 1)
        padded_tensors.append(padded_tensor)

    stacked_tensor = torch.cat(padded_tensors, dim=1)

    return stacked_tensor

@dataclass
class DataCollatorForSupervisedDataset(object):

    tokenizer: PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

        # padding for input_ids, labels, position_ids, attention_masks
        input_ids, labels, position_ids, attention_masks = tuple(
            [instance[key] for instance in instances]
            for key in ("input_ids", "labels", "position_ids", "attention_mask")
        )
        input_ids = [ids.squeeze(0) for ids in input_ids]
        labels = [ids.squeeze(0) for ids in labels]
        attention_masks = [ids.squeeze(0) for ids in attention_masks]
        # input_ids and labels are padded with the padding token and the ignore index respectively
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = pad_sequence(
            labels, batch_first=True, padding_value=-100
        )
        attention_masks = pad_sequence(
            attention_masks, batch_first=True, padding_value=0
        )
        position_ids = pad_and_cat(position_ids)

        # length truncation
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        position_ids = position_ids[:, : self.tokenizer.model_max_length]
        attention_masks = attention_masks[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            position_ids=position_ids,
            attention_mask=attention_masks,
        )

        # images do not differ in sequence "length", only in the number of patches, so concatenate along the patch dimension
        images = list(
            instance["pixel_values"]
            for instance in instances
            if "pixel_values" in instance
        )
        if len(images) != 0:
            concat_images = torch.cat([image for image in images], dim=0)
            grid_thw = [
                instance["image_grid_thw"]
                for instance in instances
                if "image_grid_thw" in instance
            ]
            grid_thw = torch.cat(grid_thw, dim=0)
        else:
            concat_images = None
            grid_thw = None

        videos = list(
            instance["pixel_values_videos"]
            for instance in instances
            if "pixel_values_videos" in instance
        )

        if len(videos) != 0:
            raise NotImplementedError("video is not supported yet")
        else:
            concat_videos = None
            video_grid_thw = None

        batch["pixel_values"] = concat_images
        batch["image_grid_thw"] = grid_thw
        batch["pixel_values_videos"] = concat_videos
        batch["video_grid_thw"] = video_grid_thw
        return batch
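A toy check (my own) of pad_and_cat: each sample's position_ids tensor has shape [3, 1, seq_len], so shorter sequences are right-padded with the value 1 on the last dimension and the samples are concatenated on dim=1.

a = torch.zeros(3, 1, 5, dtype=torch.long)
b = torch.zeros(3, 1, 3, dtype=torch.long)
out = pad_and_cat([a, b])
print(out.shape)        # torch.Size([3, 2, 5])
print(out[0, 1])        # tensor([0, 0, 0, 1, 1]) -- the padded tail of the shorter sample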

Dummy data test

model_id = '../../../DC/qwen2.5vl-3b-ins'
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
image_processor = processor.image_processor
tokenizer = processor.tokenizer

config = dict(
    dataset_use = "demo",
    image_processor = image_processor,
    tokenizer = tokenizer,
    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    default_system_message = "You are a helpful assistant."
)

dataset = LazySupervisedDataset(**config)

data_collator = DataCollatorForSupervisedDataset(
    tokenizer=tokenizer
)
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


full loading dataset: {'annotation_path': './demo/single_images.json', 'data_path': './demo/images', 'sampling_rate': 1.0}
Total training samples: 2
def print_data_dict(data_dict):
    for k, v in data_dict.items():
        print(f"{k}: {v.shape if isinstance(v, torch.Tensor) else v}")

item_1 = dataset[0]
item_2 = dataset[1]
print_data_dict(item_1)

print(20*'=')

item_batch = data_collator([item_1, item_2])
print_data_dict(item_batch)
input_ids: torch.Size([1, 495])
labels: torch.Size([1, 495])
position_ids: torch.Size([3, 1, 495])
attention_mask: torch.Size([1, 495])
pixel_values: torch.Size([1820, 1176])
image_grid_thw: torch.Size([2, 3])
====================
input_ids: torch.Size([2, 495])
labels: torch.Size([2, 495])
position_ids: torch.Size([3, 2, 495])
attention_mask: torch.Size([2, 495])
pixel_values: torch.Size([2600, 1176])
image_grid_thw: torch.Size([3, 3])
pixel_values_videos: None
video_grid_thw: None

Main training loop

Another difference between VLM and LLM fine-tuning is that a VLM is assembled from separate modules, so each module can be trained independently or jointly. The usual stages are (see the sketch after this list):

  • merger: Cross-modal Fusion
  • merger + vit: vision
  • merger + vit + LLM / just LLM: LM
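For illustration, the three stages could be expressed as flag presets for the set_model helper below; the preset names are my own, not from the repo:

STAGE_PRESETS = {
    "fusion":  dict(tune_mm_mlp=True,  tune_mm_vision=False, tune_mm_llm=False),  # merger only
    "vision":  dict(tune_mm_mlp=True,  tune_mm_vision=True,  tune_mm_llm=False),  # merger + vit
    "lm_only": dict(tune_mm_mlp=False, tune_mm_vision=False, tune_mm_llm=True),   # just LLM
    "full":    dict(tune_mm_mlp=True,  tune_mm_vision=True,  tune_mm_llm=True),   # merger + vit + LLM
}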
from transformers import Trainer, TrainingArguments

# set the model's trainable parameters: freeze/unfreeze the vision model, the merger and the LLM
def set_model(model_args, model):
    if model_args.tune_mm_vision:
        for n, p in model.visual.named_parameters():
            p.requires_grad = True
    else:
        for n, p in model.visual.named_parameters():
            p.requires_grad = False

    if model_args.tune_mm_mlp:
        for n, p in model.visual.merger.named_parameters():
            p.requires_grad = True
    else:
        for n, p in model.visual.merger.named_parameters():
            p.requires_grad = False

    # the head goes together with the LM and is trained/frozen with it
    if model_args.tune_mm_llm:
        for n, p in model.model.named_parameters():
            p.requires_grad = True
        model.lm_head.requires_grad = True
    else:
        for n, p in model.model.named_parameters():
            p.requires_grad = False
        model.lm_head.requires_grad = False

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    cache_dir="qwenvl-train-test",
    torch_dtype=torch.bfloat16,
)

data_module = dict(
    train_dataset = dataset,
    data_collator = data_collator,
)

model_args = dict(
    # flags for training the language model, the VL merger and the vision model respectively
    tune_mm_llm = True,
    tune_mm_mlp = True,
    tune_mm_vision = True,
)

training_args = TrainingArguments(
    learning_rate = 5e-5,
)

trainer = Trainer(
    model=model, processing_class=tokenizer, args=training_args, **data_module
)

Dataset example

This is the demo folder I used for testing:

├── images
│ ├── 10095.png
│ ├── 10149.png
│ └── COCO_train2014_000000580957.jpg
└── single_images.json

They are the image folder and the annotation file. The content of single_images.json is:

[
    {
        "image": "10095.png",
        "conversations": [
            {
                "role": "user",
                "content": "Is the value of Favorable 38 in 2015?\n<image>"
            },
            {
                "role": "assistant",
                "content": "Yes"
            }
        ]
    },
    {
        "image": ["10149.png", "COCO_train2014_000000580957.jpg"],
        "conversations": [
            {
                "role": "user",
                "content": "<image>\nIn which year the value was 51?\n<image>"
            },
            {
                "role": "assistant",
                "content": "2014"
            }
        ]
    }
]