白盒模型评估 SFT 数据质量

Peng Xia

白盒大模型最基础的指标就是困惑度 perplexity (PPL)。低 PPL 表示模型对输出序列 y 的概率分布预测更精确,模型对数据的“困惑”更低。高 PPL 表示模型对输出序列 y 的概率分布预测不准确,困惑程度较高。同时 PPL 是长度归一化的,可以避免直接受到长度的影响。

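给定输入序列 x,白盒大模型输出序列 y 的 PPL 计算公式如下:

$$\mathrm{PPL}(y \mid x) = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log p\left(y_i \mid x, y_{<i}\right)\right)$$

其中 N 为 y 的 token 数,$y_i$ 为 y 的第 i 个 token,$y_{<i}$ 为它之前的所有 token。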

类似于指令微调,x 是给定的,无需大模型生成;而 y 是大模型基于 x 生成的,因此 PPL 只在 y 部分的 token 上计算。

以下是用 PyTorch 和 Hugging Face 计算 PPL 的示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and the tokenizer
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

# Define the input x (instruction + input) and the output y
instruction = "Summarize the following paragraph."
input_text = "The quick brown fox jumps over the lazy dog."
output_text = "The fox jumps over the dog."
x = instruction + " " + input_text
y = output_text

# Concatenate x and y into one sequence
full_text = x + " " + y
inputs = tokenizer(full_text, return_tensors="pt")
input_ids = inputs["input_ids"]

# PPL(y|x) must only score the tokens of y: mask the x prefix with -100,
# the label value the model's built-in cross-entropy loss ignores.
prompt_len = len(tokenizer(x)["input_ids"])
labels = input_ids.clone()
labels[:, :prompt_len] = -100

# Inference only: no autograd graph is needed.
with torch.no_grad():
    outputs = model(input_ids, attention_mask=inputs["attention_mask"], labels=labels)
loss = outputs.loss  # mean cross-entropy over the unmasked (y) tokens

# PPL is the exponential of the mean negative log-likelihood
ppl = torch.exp(loss)
print(f"PPL: {ppl.item()}")
1
2
3
4
import json
# Load the Alpaca-style triplets and keep only the first 5 for the demo.
with open('./data/alpaca_gpt4_data_zh_10.json', 'r', encoding='utf-8') as fp:
    triplet_list = json.load(fp)[:5]
1
2
3
4
5
6
7
8
9
10
11
12
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = '../../DataCollection/officials/Qwen2.5-3B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)

# Target GPU 6 as before, but fall back gracefully on machines without it.
if torch.cuda.device_count() > 6:
    device = 'cuda:6'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


1

假设数据还是三元组的形式,三元组的拼接方式大部分都不会对计算困惑度有啥大影响,除非拼接方式特别离谱。反正这里我们只是筛选数据,并不是改为SFT形式的数据,所以不对三元组的拼接方式作特殊处理。

prompt 拼接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Prompt templates keyed by prompting mode, then by whether the triplet has an input.
PROMPT_DICT = {
    "zero-shot": {
        "prompt_input": (
            "Below is a description of the task and an input providing more context. "
            "Please write an appropriate response to complete the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
        ),
        "prompt_no_input": (
            "Below is a description of the task. "
            "Please write an appropriate response to complete the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:"
        ),
    },
    "one-shot": {
        "prompt_input": (
            "Below is a description of the task and an input providing more context. "
            "Please write an appropriate response to complete the request based on the example task and new task.\n\n"
            "### Instruction:\n{instruction}\n\n"
            "### Example Task\n"
            "#### Input:\n{example_input}\n\n#### Response:\n{example_output}\n\n"
            "### New Task\n"
            "#### Input:\n{input}\n\n#### Response:"
        ),
        "prompt_no_input": (
            "Below is a description of the task and a reference example. "
            "Please write an appropriate response to complete the new task based on the reference example.\n\n"
            "### Instruction:\n{instruction}\n\n"
            "### Example Task\n"
            "#### Response:\n{example_output}\n\n"
            "### New Task\n"
            "#### Instruction:\n{instruction}\n\n#### Response:"
        ),
    },
    # Reversed task: the model is asked to produce the instruction from input/response.
    "reversed-zero-shot": {
        "prompt_input": (
            "Below is an input provided for a certain instruction, and a response for completing that instruction. "
            # The template shows Input and Response, so the instruction must be
            # generated from those two (the old text said "output and response").
            "Please generate an appropriate instruction based on the input and response.\n\n"
            "### Input:\n{input}\n\n### Response:\n{output}\n\n### Instruction:"
        ),
        "prompt_no_input": (
            "Below is a response for completing a certain instruction. "
            "Please generate an appropriate instruction based on the output. \n\n"
            "### Response:\n{output}\n\n### Instruction:"
        ),
    },
}


def _template_key(triplet):
    """Return 'prompt_input' when the triplet has a non-empty input, else 'prompt_no_input'.

    A missing or None "input" is treated like an empty one, instead of raising
    KeyError or rendering the literal text "None" into the prompt.
    """
    return "prompt_input" if triplet.get("input") else "prompt_no_input"


def create_zero_shot_prompt(triplet):
    """Build the zero-shot prompt for one triplet.

    Args:
        triplet: dict of the form {"instruction": ..., "input": ..., "output": ...}.

    Returns:
        str: the filled template; it ends with "### Response:" and does not contain the output.
    """
    prompt_template = PROMPT_DICT["zero-shot"][_template_key(triplet)]
    return prompt_template.format_map(triplet)


def create_one_shot_prompt(triplet, example_io):
    """Build the one-shot prompt: the triplet's task plus one example input/output.

    Args:
        triplet: dict of the form {"instruction": ..., "input": ..., "output": ...}.
        example_io: dict of the form {"example_input": ..., "example_output": ...}.

    The template variant is chosen from the triplet's input. The prompt is meant
    mainly for triplets with an input, but also works without one; in that case
    "example_input" is not used by the template.
    """
    prompt_template = PROMPT_DICT["one-shot"][_template_key(triplet)]
    return prompt_template.format_map({**triplet, **example_io})


def create_reverse_prompt(triplet):
    """Build the reversed prompt asking the model to generate the instruction.

    The prompt contains the triplet's input (if any) and output, and ends with
    "### Instruction:". The instruction itself is not included.
    """
    prompt_template = PROMPT_DICT["reversed-zero-shot"][_template_key(triplet)]
    return prompt_template.format_map(triplet)



# Show what the zero-shot prompt of the first triplet looks like.
first_prompt = create_zero_shot_prompt(triplet_list[0])
print(first_prompt)
Below is a description of the task. Please write an appropriate response to complete the request.

### Instruction:
保持健康的三个提示。

### Response:

计算困惑度

本文参考了 IFD 的原始代码。原始代码每次只计算一条数据的困惑度,使用的是 transformers 模型 forward 函数自带的 loss 计算(再转换为困惑度),只需正常传入 input_ids 和 labels 即可。

1
2
3
4
5
6
7
8
# Tokenize one prompt (truncated to the model's maximum length) and move it to the model's device.
input_ids = tokenizer(prompt, max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')["input_ids"].to(model.device)
input_data = {
    "input_ids": input_ids,
    "labels": input_ids  # every token is scored, so this is the PPL of the whole text
}
# Inference only: without no_grad the forward pass would keep the autograd graph alive.
with torch.no_grad():
    model_ret = model(**input_data)
loss = model_ret.loss  # loss computed by the model's forward from `labels`
# Upcast before exp so a bf16 loss does not quantize the resulting PPL.
ppl = torch.exp(loss.float()).item()

这种方式逐条计算,速度太慢,没能充分利用 GPU 的并行优势,所以这部分代码我改为批量计算困惑度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq

def reorder_arrays(sort_order, *arrays, reverse=False):
    """Reorder several parallel arrays by the values of ``sort_order``.

    Args:
        sort_order: sequence of sort keys; ``sort_order[i]`` is the key of position i.
        *arrays: sequences aligned with ``sort_order``.
        reverse: sort keys in descending order when True.

    Returns:
        tuple of lists, one per input array, permuted the same way. The sort is
        stable, so equal keys keep their original relative order.
    """
    permutation = sorted(
        range(len(sort_order)),
        key=sort_order.__getitem__,
        reverse=reverse,
    )

    reordered = []
    for arr in arrays:
        reordered.append([arr[pos] for pos in permutation])
    return tuple(reordered)

def calculate_sample_perplexity(model, tokenizer, input_data, batch_size=5):
"""批量计算困惑度"""
data_collator = DataCollatorForSeq2Seq(tokenizer)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
model.eval()

data_perplexities = []
with tqdm(total=len(input_data)) as pbar:
for i in range(0, len(input_data), batch_size):

batch_data = input_data[i:i + batch_size]
batch_data_tensor = data_collator(batch_data).to(model.device)

model_outputs = model(
input_ids=batch_data_tensor['input_ids'],
attention_mask=batch_data_tensor['attention_mask']
)
logits = model_outputs.logits

shift_logits = logits[:, :-1, :].contiguous()
shift_labels = batch_data_tensor['labels'][:, 1:].contiguous()

per_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
per_token_loss = per_token_loss.view(shift_labels.size())
label_valid_mask = (shift_labels != -100)
per_sample_loss = (per_token_loss * label_valid_mask).sum(dim=1) / label_valid_mask.sum(dim=1)

batch_perplexities = torch.exp(per_sample_loss).tolist()
data_perplexities.extend(batch_perplexities)
pbar.update(len(batch_data))
return data_perplexities

def calculate_sample_perplexity_resortable(model, tokenizer, input_data, batch_size=5, sort_by_len:bool=False):
"""批量计算困惑度, but可以重排序以减少padding"""

if sort_by_len:
input_data_len = [len(item["input_ids"]) for item in input_data]
input_data_indice = list(range(len(input_data)))
input_data, input_data_indice = reorder_arrays(input_data_len, input_data, input_data_indice)

data_collator = DataCollatorForSeq2Seq(tokenizer)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
model.eval()

data_perplexities = []
with tqdm(total=len(input_data)) as pbar:
for i in range(0, len(input_data), batch_size):

batch_data = input_data[i:i + batch_size]
batch_data_tensor = data_collator(batch_data).to(model.device)

model_outputs = model(
input_ids=batch_data_tensor['input_ids'],
attention_mask=batch_data_tensor['attention_mask']
)
logits = model_outputs.logits

shift_logits = logits[:, :-1, :].contiguous()
shift_labels = batch_data_tensor['labels'][:, 1:].contiguous()

per_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
per_token_loss = per_token_loss.view(shift_labels.size())
label_valid_mask = (shift_labels != -100)
per_sample_loss = (per_token_loss * label_valid_mask).sum(dim=1) / label_valid_mask.sum(dim=1)

batch_perplexities = torch.exp(per_sample_loss).tolist()
data_perplexities.extend(batch_perplexities)
pbar.update(len(batch_data))

if sort_by_len:
input_data_perplexities, = reorder_arrays(input_data_indice, input_data_perplexities)

return data_perplexities

困惑度类指标

Prompt 困惑度

Prompt = Instruction + Input。如果指令写的很清晰的话,大模型理解指令这个 prompt 文本就很容易,理应更不困惑,所以就可以用 PPL(x) 来衡量大模型对 prompt 文本的理解清晰程度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def get_prompt_complexity(triplet_list, model, tokenizer, batch_size:int=10):
    """Return PPL(x) for every triplet, where x is its zero-shot prompt.

    Every prompt token is used as a label, so the value measures how
    "surprising" the prompt text itself is to the model.
    """
    max_len = tokenizer.model_max_length
    encoded_prompts = (
        tokenizer.encode(create_zero_shot_prompt(triplet), max_length=max_len, truncation=True)
        for triplet in triplet_list
    )
    # Labels are the input ids themselves: the whole prompt is scored.
    input_data = [{"input_ids": ids, "labels": ids} for ids in encoded_prompts]
    return calculate_sample_perplexity_resortable(model, tokenizer, input_data, batch_size)

get_prompt_complexity(triplet_list, model, tokenizer)
100%|██████████| 5/5 [00:01<00:00,  3.26it/s]





[185.0, 123.0, 115.5, 139.0, 79.5]

Response 困惑度

一方面好的 prompt 更易于让模型输出对应的 output,另一方面好的 output 也更容易在给定 prompt 的情况下生成。所以计算 PPL(y|x) 对 prompt 和 output 都有一定衡量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from copy import deepcopy

def get_response_complexity(triplet_list, model, tokenizer, batch_size:int=10):
    """Return PPL(y|x) for every triplet: the output's perplexity given the zero-shot prompt.

    The prompt and the output are tokenized together. The leading prompt-length
    tokens are masked with -100, so only the output tokens are scored.
    """
    max_len = tokenizer.model_max_length
    input_data = []
    for triplet in triplet_list:
        prompt = create_zero_shot_prompt(triplet)
        # (The old code also tokenized the output alone here and then discarded it.)
        input_ids = tokenizer.encode(prompt + triplet["output"], max_length=max_len, truncation=True)
        prompt_len = len(tokenizer.encode(prompt, max_length=max_len, truncation=True))
        # If truncation cut into the prompt, every position ends up masked.
        n_masked = min(prompt_len, len(input_ids))
        labels = [-100] * n_masked + input_ids[n_masked:]
        input_data.append({"input_ids": input_ids, "labels": labels})

    return calculate_sample_perplexity_resortable(model, tokenizer, input_data, batch_size)

get_response_complexity(triplet_list, model, tokenizer)
100%|██████████| 5/5 [00:00<00:00, 33.38it/s]





[3.921875, 5.90625, 5.03125, 4.59375, 12.0]

指令跟随难度 Instruction Following Difficulty (IFD)

上述的 Output 困惑度虽然包含了 prompt 对生成 output 的帮助程度,但同时也和 output 本身生成也有一定关系,因此Output 困惑度包含了两方面:

  • prompt 对生成 output 的帮助程度
  • output 本身生成的容易程度

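由于后者的存在,该困惑度与 output 本身也耦合,例如更长更复杂的 output 的 PPL 相比于短又直接的 output 更大。因此有必要将 “output 本身生成的容易程度” 剥离,从而仅保留 “prompt 对生成 output 的帮助程度”。而 “output 本身生成的容易程度” 本质上可以由 output 本身的 PPL 代表,因此 “prompt 对生成 output 的帮助程度” 可以表达为:

$$\mathrm{IFD}(x, y) = \frac{\mathrm{PPL}(y \mid x)}{\mathrm{PPL}(y)}$$

其中 $\mathrm{PPL}(y)$ 是不提供 prompt、直接让模型生成 output 时的困惑度。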

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def create_IFD_input_data(triplet_list, tokenizer):
    """Build the two datasets needed for IFD = PPL(y|x) / PPL(y).

    Returns:
        (data_whole_text, data_output_only), both lists of
        {"input_ids", "labels"} dicts aligned with ``triplet_list``:
        * data_whole_text: zero-shot prompt + output, prompt tokens masked with -100.
        * data_output_only: the output alone, every token used as a label.
    """
    max_len = tokenizer.model_max_length

    def encode(text):
        return tokenizer.encode(text, max_length=max_len, truncation=True)

    data_whole_text, data_output_only = [], []
    for triplet in triplet_list:
        response = triplet["output"]

        response_ids = encode(response)
        data_output_only.append({"input_ids": response_ids, "labels": response_ids})

        prompt = create_zero_shot_prompt(triplet)
        full_ids = encode(prompt + response)
        n_masked = min(len(encode(prompt)), len(full_ids))
        data_whole_text.append({
            "input_ids": full_ids,
            "labels": [-100] * n_masked + full_ids[n_masked:],
        })

    return data_whole_text, data_output_only

def get_IFD(triplet_list, model, tokenizer, batch_size:int=10):
    """Instruction Following Difficulty for every triplet.

    IFD = PPL(y|x) / PPL(y), where x is the zero-shot prompt (instruction + input)
    and y is the output. A value below 1 means the prompt makes the output
    easier to generate.
    """
    data_whole_text, data_output_only = create_IFD_input_data(triplet_list, tokenizer)

    ppl_conditional = calculate_sample_perplexity_resortable(model, tokenizer, data_whole_text, batch_size)
    ppl_unconditional = calculate_sample_perplexity_resortable(model, tokenizer, data_output_only, batch_size)

    return [cond / uncond for cond, uncond in zip(ppl_conditional, ppl_unconditional)]

get_IFD(triplet_list, model, tokenizer)
100%|██████████| 5/5 [00:00<00:00, 70.08it/s]
100%|██████████| 5/5 [00:00<00:00, 69.38it/s]





[0.9507575757575758,
 1.0617977528089888,
 1.0125786163522013,
 1.027972027972028,
 0.9411764705882353]

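IFD 越小,说明 prompt 对生成 output 的帮助越大;IFD 接近或大于 1,说明 prompt 几乎没有帮助甚至起反作用,这类数据通常应被剔除。注意:原始 IFD 论文(Cherry LLM,"From Quantity to Quality")是在 IFD < 1 的数据中优先挑选 IFD 较高、即对模型而言更难的样本。因此“越小越好”只适用于衡量 prompt 与 output 的匹配程度,并不等同于数据的训练价值。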

指令生成难度 Instruction Generation Difficulty (IGD)

指令跟随难度评估对象是指令,同理需要评估回复。类似于指令有助于生成回复,回复反向能有助于生成对应的指令。现在将 output 和 instruction 的位置翻转,任务变成基于回复生成指令。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def create_IGD_input_data(triplet_list, tokenizer):
    """Build the two datasets needed for IGD = PPL(instruction | reverse prompt) / PPL(instruction).

    Roles are swapped relative to IFD. The text being scored is the instruction,
    and the conditioning prompt (``create_reverse_prompt``) contains the input
    (if any) and the output.

    Returns:
        (data_whole_text, data_output_only), both aligned with ``triplet_list``:
        * data_whole_text: reverse prompt + instruction, prompt tokens masked with -100.
        * data_output_only: the instruction alone, every token used as a label.
    """
    max_len = tokenizer.model_max_length

    def encode(text):
        return tokenizer.encode(text, max_length=max_len, truncation=True)

    data_whole_text, data_output_only = [], []
    for triplet in triplet_list:
        target_text = triplet["instruction"]

        target_ids = encode(target_text)
        data_output_only.append({"input_ids": target_ids, "labels": target_ids})

        reverse_prompt = create_reverse_prompt(triplet)
        full_ids = encode(reverse_prompt + target_text)
        n_masked = min(len(encode(reverse_prompt)), len(full_ids))
        data_whole_text.append({
            "input_ids": full_ids,
            "labels": [-100] * n_masked + full_ids[n_masked:],
        })

    return data_whole_text, data_output_only

def get_IGD(triplet_list, model, tokenizer, batch_size:int=10):
    """Instruction Generation Difficulty for every triplet.

    IGD = PPL(instruction | reverse prompt) / PPL(instruction), where the
    reverse prompt holds the input and output. A value below 1 means the
    input/output make the instruction easier to generate.
    """
    data_whole_text, data_output_only = create_IGD_input_data(triplet_list, tokenizer)

    conditional, unconditional = (
        calculate_sample_perplexity_resortable(model, tokenizer, dataset, batch_size)
        for dataset in (data_whole_text, data_output_only)
    )
    return [c / u for c, u in zip(conditional, unconditional)]


get_IGD(triplet_list, model, tokenizer)
  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 68.98it/s]
100%|██████████| 5/5 [00:00<00:00, 72.23it/s]





[0.4421768707482993,
 1.0,
 0.6465116279069767,
 0.5043103448275862,
 0.8049792531120332]

One-shot 有效性

众所周知,当无法微调模型时,如果指令不够清晰或者比较复杂,可以通过加入输入输出例子,让模型通过 Few-shot learning / In-context learning 学习指令该如何执行。因此,在 shot 本身合理的前提下,这些 shot / 例子可以改善大模型对指令的理解;反之,如果 shot 不合理,效果甚至不如不加 shot。因此,通过对比一条数据作为 shot 时对其他数据生成的帮助,就可以判断这条数据作为 shot 的合理性/有效性。

One-shot 有效性可计算为:以需评估的数据为 shot, 将该 shot 以 one-shot learning 的方法对生成其他数据的影响。

注意:

  • 计算时需要随机抽取以计算平均值
  • 由于 PPL 会与对应的 output 相关,最好只取符号值,即二元判断有无帮助
  • 由于 Few-shot learning 仅在同种任务的数据上有用,所以有必要先根据任务类型再计算
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import random
import numpy as np
from typing import Literal
from collections import defaultdict

def list_to_index_dict(input_list):
    """Map each distinct value of ``input_list`` to the ascending list of positions where it occurs.

    Keys appear in order of first occurrence; the result is a plain dict.
    """
    positions_by_value = defaultdict(list)
    for position, value in enumerate(input_list):
        positions_by_value[value].append(position)
    return dict(positions_by_value)

def get_random_non_m_values(indice_pool, m, n):
    """Draw ``n`` distinct random values from ``indice_pool`` excluding every occurrence of ``m``.

    If fewer than ``n`` candidates remain, all of them are returned in their
    original order (no sampling is done in that case).
    """
    candidates = list(filter(lambda value: value != m, indice_pool))
    if len(candidates) >= n:
        return random.sample(candidates, n)
    return candidates

def create_zero_and_one_shot_input_data(triplet, example, tokenizer):
    """Build the zero-shot and one-shot inputs for scoring ``triplet``'s output.

    Note the roles: ``example`` ({"example_input", "example_output"}) is the data
    being evaluated, and ``triplet`` is a randomly drawn sample of the same task.
    In both inputs the prompt tokens are masked with -100, so only the
    triplet's output is scored.

    Returns:
        (data_zero_shot, data_one_shot): two {"input_ids", "labels"} dicts.
    """
    max_len = tokenizer.model_max_length
    response = triplet["output"]

    def build(prompt):
        # Tokenize prompt+response together and mask the prompt-length prefix.
        full_ids = tokenizer.encode(prompt + response, max_length=max_len, truncation=True)
        prompt_len = len(tokenizer.encode(prompt, max_length=max_len, truncation=True))
        n_masked = min(prompt_len, len(full_ids))
        return {"input_ids": full_ids, "labels": [-100] * n_masked + full_ids[n_masked:]}

    data_zero_shot = build(create_zero_shot_prompt(triplet))
    data_one_shot = build(create_one_shot_prompt(triplet, example))
    return data_zero_shot, data_one_shot


def get_One_Shot_Example_Validity(triplet_list, triplet_task_list,
                                  model, tokenizer,
                                  one_shot_sample_cnt:int=3, validity_calculation:Literal['raw', 'sign']='sign',
                                  batch_size=10):
    """Score each triplet by how much it helps when used as a one-shot example.

    For every triplet i, up to ``one_shot_sample_cnt`` other triplets of the same
    task are drawn at random as targets. For each target, the output PPL is
    computed zero-shot and one-shot (with triplet i's input/output as the
    example). The gain ``PPL_zero - PPL_one`` is recorded; positive means the
    example helped.

    Args:
        triplet_list: list of {"instruction", "input", "output"} dicts.
        triplet_task_list: task label per triplet, aligned with ``triplet_list``;
            targets are only drawn from the same task.
        model, tokenizer: forwarded to the perplexity computation.
        one_shot_sample_cnt: maximum number of targets sampled per triplet.
        validity_calculation: 'raw' keeps the PPL difference, 'sign' keeps only its sign.
        batch_size: batch size for the perplexity computation.

    Returns:
        list aligned with ``triplet_list``: the mean gain of each triplet, or NaN
        when no other triplet of the same task exists.

    Raises:
        ValueError: if ``validity_calculation`` is neither 'raw' nor 'sign'.
    """
    # Validate up front, so a bad mode fails before any perplexity is computed.
    if validity_calculation not in ("raw", "sign"):
        raise ValueError(
            f"validity_calculation must be 'raw' or 'sign', got {validity_calculation!r}"
        )

    task2indexs = list_to_index_dict(triplet_task_list)

    input_data_zero_shot = []
    input_data_one_shot = []
    input_data_indices = []  # evaluated triplet index for each (zero, one)-shot pair

    for indice, (triplet, task) in tqdm(enumerate(zip(triplet_list, triplet_task_list)),
                                        total=len(triplet_list)):
        # Targets: other triplets of the same task, sampled at random.
        indice_pool = get_random_non_m_values(task2indexs[task], indice, one_shot_sample_cnt)

        example = {
            "example_input": triplet["input"],
            "example_output": triplet["output"],
        }
        for target_index in indice_pool:
            data_zero_shot, data_one_shot = create_zero_and_one_shot_input_data(
                triplet=triplet_list[target_index],
                example=example,
                tokenizer=tokenizer,
            )
            input_data_zero_shot.append(data_zero_shot)
            input_data_one_shot.append(data_one_shot)
            input_data_indices.append(indice)

    ppl_zero_shot = calculate_sample_perplexity_resortable(model, tokenizer, input_data_zero_shot, batch_size)
    ppl_one_shot = calculate_sample_perplexity_resortable(model, tokenizer, input_data_one_shot, batch_size)

    # Gain of one-shot over zero-shot: positive means the example lowered the PPL.
    validity = np.asarray(ppl_zero_shot, dtype=float) - np.asarray(ppl_one_shot, dtype=float)
    if validity_calculation == "sign":
        validity = np.sign(validity)

    # Average the gains per evaluated triplet; NaN when it had no target.
    gains_by_triplet = defaultdict(list)
    for idx, val in zip(input_data_indices, validity):
        gains_by_triplet[idx].append(val)

    return [
        np.mean(gains_by_triplet[i]) if i in gains_by_triplet else np.nan
        for i in range(len(triplet_list))
    ]

# Toy sentiment-classification data; all three samples share one instruction/task.
# NOTE(review): the third label ('高兴' for being scolded by the boss) looks
# deliberately wrong, presumably to show a low-validity example; confirm.
_sentiment_instruction = '分类以下句子的情感为伤心、高兴、正常。'
triplet_list_test = [
    {'instruction': _sentiment_instruction, 'input': sentence, 'output': label}
    for sentence, label in [
        ('今天中午吃什么?', '正常'),
        ('我的科三挂了。', '伤心'),
        ('今天上班被老板骂了。', '高兴'),
    ]
]
triplet_task_list_test = ['情感分类'] * 3

get_One_Shot_Example_Validity(triplet_list_test, triplet_task_list_test, model, tokenizer, 2)
3it [00:00, 425.99it/s]
100%|██████████| 6/6 [00:00<00:00, 115.09it/s]
100%|██████████| 6/6 [00:00<00:00, 115.74it/s]





[1.0, 0.0, 0.0]

这个指标更偏向于评估输入输出,而不是指令。

更正

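以上代码可能会出现计算结果为 nan 的情况。当 output 只有一个 token 时,计算 PPL(y) 所用的“仅 output”序列(IFD/IGD 中的 data_output_only)就只有这一个 token。由于 logits 与 labels 需要错开一位(label shift),序列的第一个 token 没有上文可用来预测它;对应地,最后一个位置的 logits 也没有下一个 label。于是没有任何位置参与 loss 计算,平均 loss 变成 0/0 = nan。带 prompt 的整段序列不受影响,因为 output 前面还有 prompt 的 token。解决办法是在 output 之后再加入一个 token。注意追加的 token 本身也会参与 loss 计算,会轻微改变 PPL 数值,EOS 是比较自然的选择。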

1
# Before the fix: the response text exactly as stored in the triplet.
output = triplet["output"]

改为

1
# After the fix: append the EOS token so that a single-token output still leaves
# one scored position in the output-only sequence. Because of the one-position
# label shift, the first token of a sequence is never scored.
output = triplet["output"] + tokenizer.eos_token
Comments