LLM 推理 & speculative sampling

Peng Xia

推理基本方式

贪心解码

类似于分类器通常会选择概率最大的标签,对于文本生成任务,最直接的方法就是每次取概率最大的token_id,接下来我以贪心搜索为例介绍文本生成的流程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Encode initial input
input_text = "What is star war?"
input_ids = tokenizer.encode(input_text, return_tensors='pt').to(DEVICE) # Shape: [1, 4]

# Set the number of tokens to generate
num_tokens_to_generate = 100

# Iteratively generate tokens
# for _ in tqdm(range(num_tokens_to_generate), mininterval=1):
for _ in range(num_tokens_to_generate):

# Get model output logits
outputs = model(input_ids) # Shape: [1, current_length, 50257] or [batch_size, token length, vocab size]
logits = outputs.logits

'''
Predict the next token based on the last position
i.e., the i-th position logits is for predicting the i+1-th token
In this case, we want to predict the next token based on previous tokens, so we use the logits of the final token.
If you see the source code of forward function, you can notice the shifting of labels and logits for aligning.
'''
next_token_logits = logits[:, -1, :] # Shape: [1, 50257], corresponding to each vocab

'''
Greedy decoding: select the token with the highest probability
Supposily you can try top-k and beam search
'''
greedy_token_id = torch.argmax(next_token_logits, dim=-1) # Shape: [1]

# Append the predicted token to the input_ids
input_ids = torch.cat([input_ids, greedy_token_id.unsqueeze(-1)], dim=-1).to(DEVICE) # Shape: [1, current_length + 1]

# print(tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True))

# Decode the entire sequence of tokens
generated_text = tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True)
print("Generated Text:\n", generated_text)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# encode context the generation is conditioned on
model_inputs = tokenizer('I enjoy walking with my cute dog', return_tensors='pt').to(DEVICE)

pprint(model_inputs, width=100)

# generate 40 new tokens
# the output of generate is a `GenerateDecoderOnlyOutput` object, we only need the first attribute.
greedy_output = model.generate(**model_inputs,
max_new_tokens=40,
# max_length=50,
)

token_ids = torch.squeeze(greedy_output[0])
print(tokenizer.decode(token_ids, skip_special_tokens=True))

LLM的 logit 是 batch size * seq length * token num,在训练时,模型通常是用第 i-1 位的 logit 去预测第 i 位的 token id,所以预测时用最后一位的 logit 去预测新的 token id,结果softrmax后,我们得到新的 token id,然后按照 auto regressive 的风格,把新 token id加入到 之前的 token id 序列中, 再用相同的方法得到下一个 token id。这就是预测的整体流程, 贪心搜索只不过决定了新的 token id 的选择。

以图中的概率发布,解码结果位 ‘The cat is’。

在给定上下文生成的词语是合理的,但模型很快开始重复自己!这是语言生成中一个非常常见的问题,尤其在贪婪搜索和波束搜索中更为明显。然而,贪婪搜索的主要缺点是它会错过那些被低概率词遮挡的高概率词。其实是因为它的视野窗口只有1,因此无法做出偏长远的判断

top-k采样

在 top-k采样中,模型会筛选出 k 个最可能的下一 token,并将概率质量重新分配到这 K 个token,然后在他们之中随机采样。

Understanding LLM Decoding Strategies | by LM Po | Medium

  1. 计算token概率:在模型处理输入文本后,它会预测可能的下一个 token 的概率分布。

  2. 筛选 top-k:与考虑所有可能的 token 不同,top-k 采样将选择范围缩小到概率最高的 k 个 token 。这种“剪枝”减少了潜在输出空间,专注于最可能的下一 token ,而忽略了不太可能的选项。

  3. 随机采样:从 top-k token 中,重新分配他们的概率(一般再除以概率和,使得整体概率和仍对于1),根据它们的概率随机采样一个 token ,而不是总是选择最高概率的词元。这种方式引入了多样性,使生成的文本更加丰富多样。

1
2
3
4
5
6
7
8
topk_output = model.generate(**model_inputs, 
max_new_tokens=40,
do_sample=True,
top_k=50
)

token_ids = torch.squeeze(topk_output[0])
print(tokenizer.decode(token_ids, skip_special_tokens=True))

通过调整 k 的值,高 k 值(例如 50 或 100)允许更多的选择,增加了多样性和创造性,但可能降低连贯性。低 k 值(例如 5 或 10)限制了选项,通常使文本更具确定性和集中性,但有时也会导致内容过于重复或保守。

然而,top-k采样的一个问题是,它不会动态调整从下一个 token 的概率分布中过滤的 token 。这可能会导致问题,因为某些 token 可能来自非常陡峭的分布(即集中于少数词的分布,这时更容易加入很低概率的 token ,保留他们的意义不大),而其他词则来自更平坦的分(这种情况 token 的选择更正常)。例如图中左侧发布相对平坦,概率更均匀, top-6 随机采样时合理的,右侧发布主要集中在 top-3 中,如果取 top-6,后三个词的概率太小了,往往不可能选中,因此这样的概率在一开始就不需要成为候选。

因此,将采样池限制为固定大小 k 可能会导致模型在陡峭分布中产生无意义的内容,而在平坦分布中限制模型的创造力。

top-p采样

与仅从最可能的 k 个 token 中采样不同,top-p 采样选择的是累计概率超过阈值 p 的最小 token 集合。

它与 top-k 的区别仅在于筛选方式。top-k 采样选择个体概率最高的前 k 个 token ,而 top-p 采样则考虑累计概率至少为 p 的最小 token 集合,即概率较大的那些 token 。

1
2
3
4
5
6
7
8
topp_output = model.generate(**model_inputs, 
max_new_tokens=40,
do_sample=True,
top_p=0.92
)

token_ids = torch.squeeze(topp_output[0])
print(tokenizer.decode(token_ids, skip_special_tokens=True))

在开放式语言生成中,top-p 和 top-k 采样似乎比传统的贪心搜索和 beam search 生成更流畅的文本。但有人证明,贪心搜索和 beam search 的明显缺陷(主要是生成重复词序列)是由模型本身(尤其是模型的训练方式)引起的,而不是生成方法的问题

temperature

采样方法过程中都涉及到对部分数据采样,但通常都会把他们原始的概率进行调整。概率调整时的重要参数就是 temperature,它在文本生成中用于控制生成内容的随机性或创造性。通过调整可能生成的下一个 token 的概率分布,temperature 参数可以让模型生成的文本更保守或更具有创意。

temperature 会对每个可能生成的下一个 token 的概率分布进行缩放。temperature 越高,分布越“平”,即模型对每个候选 token 的选择倾向性降低,更可能从更大范围的 token 中随机选取下一个 token 。temperature 越低,分布越“尖”,模型会更偏向选择高概率的候选 token ,这样生成的文本就会更可预测和连贯。

speculative sampling

speculative sampling 基本思想时利用小模型先自回归生成,然后大模型批量评估,如果接受则继续生成,如果不接受,利用大模型评估的结果重新计算对应位置的token。

image-20250630002233791

步骤如下:

  1. draft model 基于 prefix 逐个生成 ,其中 是小模型生成的长度,同时也要返回对于位置的分布概率
  2. 用 target model 基于 为 input_ids 生成对应位置的发布概率
  3. 对于每个位置,以 target model 的 token 概率与 draft model 的 token 概率为接受率 ,取一个随机数,如果随机了大于此概率则接受,这个位置的 token 就是 ,否则就是不接受,那这个位置的 token 就直接基于 来推理,之后的 token 就没法推理了,因为此后 target model 的发布概率是基于不被接受的 token 产生的。
  4. 如果所有 token 都被接受,那可以利用 target model 产生的最后一位的发布概况,直接推理出新的token

image-20250630003324598

原论文有半页内容解释为什么这样子推理的概率发布等同于 target model 的概率发布,很简单,不解释了。

至于为什么快了,因为推理的时间成本在于自回归的解码,只有先生成上一个位置的 token,才能把 token 加入然后生成新的 token。模型越大,生成每一个 token 的时间越长。speculative sampling 就是让小模型先生成,再大模型验证,而且大模型的发布概率是不需要一次一次计算的,是基于一段新的token 批量计算的,因此会快。

但 speculative sampling 不能与 top k 和 top p 共同设置,因为接受率 需要保证 非零,而 top p 和 top k 是会将大部分概率赋值为 0,因此就会冲突。而 temperature 是缩放,因此不受影响,可以共同使用。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch, time
from tqdm import tqdm, trange
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = '../../DC/qwen2-1.5b-ins'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map='auto',
)

huggingface model generate 函数

正常推理可以直接用 model.generate 函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
text = 'Do not go gentle into that good night,\nOld age should burn and rave at close of day;'

input_ids = tokenizer(text, return_tensors='pt').to(model.device)
start = time.time()
with torch.no_grad():
outputs = model.generate(
**input_ids,
max_new_tokens=20,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f'Inference time: {time.time() - start:.2f} seconds')

1
2
3
Do not go gentle into that good night,
Old age should burn and rave at close of day; The wind is still, the rain is falling, The stars are all bright and bright. But I
Inference time: 0.64 seconds

naive 贪心解码 (带kv cache)

在调用 model.forward(input_ids=input_ids, past_key_values=past_kv) 时:input_ids 中应该只包含“新输入的 token”,也就是**还未被缓存(未进入 past_key_values)**的 token。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def autoregressive_generate_by_greedy_search(
model,
tokenizer,
text,
max_new_tokens: int = 10,
use_past_key_values: bool = True
):
device = model.device
input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to(device) # 初始输入 [B=1,S]

with torch.no_grad():
start = time.time()
past_key_values = None

for _ in trange(max_new_tokens):
if use_past_key_values and past_key_values is not None:
# 仅输入最近生成的 token
input_ids_step = input_ids[:, -1:] # shape: [B, 1]
else:
# 第一步或不使用缓存:输入当前全部输入(会随着 input_ids 增长)
input_ids_step = input_ids

outputs = model(input_ids=input_ids_step, past_key_values=past_key_values if use_past_key_values else None)
logits = outputs.logits # shape: [B, S, V]
past_key_values = outputs.past_key_values if use_past_key_values else None

logits_last_token = logits[:, -1, :] # shape: [B, V]
next_token_id = logits_last_token.argmax(dim=-1, keepdim=True) # shape: [B, 1]

# 更新 input_ids 以便下一轮使用
input_ids = torch.cat([input_ids, next_token_id], dim=-1)

print(f'Inference time: {time.time() - start:.2f} seconds')

generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
return generated_text

text = 'Do not go gentle into that good night,\nOld age should burn and rave at close of day;'
generated_text = autoregressive_generate_by_greedy_search(model, tokenizer, text, max_new_tokens=20, use_past_key_values=True)
print(generated_text)

generated_text = autoregressive_generate_by_greedy_search(model, tokenizer, text, max_new_tokens=20, use_past_key_values=False)
print(generated_text)

1
2
3
4
5
6
7
8
9
10
11
100%|██████████| 20/20 [00:00<00:00, 26.17it/s]
Inference time: 0.77 seconds
Do not go gentle into that good night,
Old age should burn and rave at close of day; The night is not set for us between wind and falling leaves.
A good night is not enough.
100%|██████████| 20/20 [00:00<00:00, 33.57it/s]
Inference time: 0.60 seconds
Do not go gentle into that good night,
Old age should burn and rave at close of day; The night is not set for us between wind and falling leaves.
A good night is not enough.

naive top k & top p & temperature 解码 (带 kv cache)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def apply_top_k_filter(probs: torch.Tensor, top_k: int = 0):
"""
将 probs 中非 top k 的值置为 -inf
"""
if top_k <= 0:
return probs
vocab_size = probs.size(-1)
top_k = min(top_k, vocab_size)
# 获得 top k, 将非 top k 的 probs 置为 0
top_k_values, top_k_indices = torch.topk(probs, top_k)
min_top_k_values = top_k_values[:, -1].unsqueeze(-1)
indices_to_remove = probs < min_top_k_values
probs = probs.masked_fill(indices_to_remove, 0)
return probs


def apply_top_p_filter(probs: torch.Tensor, top_p: float = 0.0):
"""
将 probs top 累积概率超过 top_p 的值置保留,其他置为 0
"""
if top_p <= 0.0:
return probs

# 排序 probs
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
# 计算累计概率
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 计算累计概率 > top p 的掩码,需要前移一位,最大概率位始终要保留
remove_mask = cumulative_probs > top_p
remove_mask[..., 1:] = remove_mask[..., :-1].clone()
remove_mask[..., 0] = 0
indices_to_remove = remove_mask.scatter(1, sorted_indices, remove_mask)
probs = probs.masked_fill(indices_to_remove, 0)

return probs

def apply_temperature(logits: torch.Tensor, temperature: float = 1.0):
"""
将 logits 除以 temperature, temperature 越大越平滑,越小越尖锐
"""
if temperature <= 0.0:
raise ValueError("Temperature must be greater than 0.")
return logits / temperature

def re_normalize_probs(probs: torch.Tensor):
"""
将 top p 和 top k 过滤后的概率重新归一化为概率分布
"""
prob_sums = torch.sum(probs, dim=-1, keepdim=True)
probs = probs / prob_sums
return probs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def autoregressive_generate_by_sampling(
model,
tokenizer,
text,
max_new_tokens:int=10,
top_k:int=10,
top_p:float=0.9,
temperature:float=1.0,
use_past_key_values:bool=True
):
"""
使用采样方法生成文本
"""
device = model.device
input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to(device) # [B=1,S]

with torch.no_grad():

start = time.time()
past_key_values = None

for _ in trange(max_new_tokens):

if use_past_key_values and past_key_values is not None:
# 仅输入最近生成的 token
input_ids_step = input_ids[:, -1:] # shape: [B, 1]
else:
# 第一步或不使用缓存:输入所有 tokens
input_ids_step = input_ids

outputs = model(input_ids=input_ids_step, past_key_values=past_key_values if use_past_key_values else None)
logits = outputs.logits # shape: [B, S, V]
past_key_values = outputs.past_key_values if use_past_key_values else None

logits_last_token = logits[:, -1, :]
logits_last_token = apply_temperature(logits_last_token, temperature)
probs = F.softmax(logits_last_token, dim=-1) # [B,S,V]
probs = apply_top_k_filter(probs, top_k)
probs = apply_top_p_filter(probs, top_p)
probs = re_normalize_probs(probs)
next_token_id = torch.multinomial(probs, num_samples=1) # [B,S,V]

input_ids = torch.cat([input_ids, next_token_id], dim=-1)

print(f'Inference time: {time.time() - start:.2f} seconds')

generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
return generated_text

text = 'Do not go gentle into that good night,\nOld age should burn and rave at close of day;'
generated_text = autoregressive_generate_by_sampling(model, tokenizer, text, max_new_tokens=20, use_past_key_values=True)
print(generated_text)

generated_text = autoregressive_generate_by_sampling(model, tokenizer, text, max_new_tokens=20, use_past_key_values=False)
print(generated_text)
1
2
3
4
5
6
7
8
9
100%|██████████| 20/20 [00:00<00:00, 30.06it/s]
Inference time: 0.67 seconds
Do not go gentle into that good night,
Old age should burn and rave at close of day; Old age might’s well ask what’s in store.
For we are not simply dead, we are
100%|██████████| 20/20 [00:00<00:00, 32.77it/s]
Inference time: 0.61 seconds
Do not go gentle into that good night,
Old age should burn and rave at close of day;The summer bugle grew and the day but faltered;The roosting-dove peeps

speculative sampling

1
2
3
4
5
6
7
8
9
10
11
12
13
target_model = AutoModelForCausalLM.from_pretrained(
'../../DC/opt-1.3b',
torch_dtype=torch.bfloat16,
device_map='auto',
)


tokenizer = AutoTokenizer.from_pretrained('../../DC/opt-1.3b')
draft_model = AutoModelForCausalLM.from_pretrained(
'../../DC/opt-6.7b',
torch_dtype=torch.bfloat16,
device_map='auto',
)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def truncate_kv_cache(past_key_values, new_length):
new_past = []
for layer_past in past_key_values:
k, v = layer_past
k = k[:, :, :new_length, :]
v = v[:, :, :new_length, :]
new_past.append((k, v))
return new_past


def sample_from_draft_model(model, input_ids, past_key_values, max_new_tokens=4, temperature=1.0):
"""
从模型中采样生成文本
对于 draft 模型,需要返回 output 和 prob,前者用于计算 target model 的 prob,后者用于判断是否保留
"""
device = model.device
input_ids = input_ids.to(device) # [B=1,S]

collected_probs = []

with torch.no_grad():
for _ in range(max_new_tokens):
outputs = model(input_ids, past_key_values=past_key_values)

past_key_values = outputs.past_key_values
next_token_logits = outputs.logits[:, -1, :]

next_token_logits = apply_temperature(next_token_logits, temperature)
next_token_probs = torch.softmax(next_token_logits, dim=-1)
next_token_id = torch.multinomial(next_token_probs, num_samples=1)

input_ids = torch.cat([input_ids, next_token_id], dim=-1)
collected_probs.append(next_token_probs)

draft_new_token_probs = torch.cat(collected_probs, dim=0)

return input_ids, draft_new_token_probs, past_key_values


def get_target_model_probs(model, input_ids, num_new_tokens=4, past_key_values=None):
"""
获取 target 模型的 logits
"""
device = model.device
input_ids = input_ids.to(device) # [B=1,S]

with torch.no_grad():
outputs = model(input_ids, past_key_values=past_key_values)
logits = outputs.logits
new_token_logits = logits[0, -num_new_tokens-1:, :]
new_token_logits = apply_temperature(new_token_logits, temperature=1.0)
new_token_probs = torch.softmax(new_token_logits, dim=-1)

return new_token_probs[:num_new_tokens], new_token_probs[-1], outputs.past_key_values


text = 'Do not go gentle into that good night,\nOld age should burn and rave at close of day;'
input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to(draft_model.device)

max_new_tokens = 20
lookahead = 10
temperature = 1.0
epsilon = 1e-9

draft_past_key_values = None
target_past_key_values = None

target_length = input_ids.shape[1] + max_new_tokens

generated_ids = input_ids.clone()

while generated_ids.shape[1] < len(input_ids[0]) + max_new_tokens:

prompt_len = generated_ids.shape[1]

draft_input_ids, draft_new_token_probs, draft_past_key_values = sample_from_draft_model(
draft_model, generated_ids, draft_past_key_values, max_new_tokens=lookahead, temperature=temperature
)

target_new_token_probs, target_fianl_token_probs, target_past_key_values = get_target_model_probs(
target_model, draft_input_ids, num_new_tokens=lookahead, past_key_values=target_past_key_values
)

num_accepted = 0
all_accepted = True

for i in range(lookahead):

drafted_token_id = draft_input_ids[0, prompt_len + i]

p_val = target_new_token_probs[i][drafted_token_id.item()]
q_val = draft_new_token_probs[i][drafted_token_id.item()]

ratio = p_val / (q_val + epsilon) # Add epsilon to avoid division by zero
random_num = torch.rand(1)

## Acceptance
if random_num.item() < ratio.item():

generated_ids = torch.cat([generated_ids, drafted_token_id.unsqueeze(0).unsqueeze(0)], dim=-1)
num_accepted += 1

## Rejection
else:

all_accepted = False
new_dist = target_new_token_probs[i] - draft_new_token_probs[i]
new_dist = torch.max(torch.zeros_like(new_dist), new_dist)
new_dist = re_normalize_probs(new_dist)

resampled_token_id = torch.multinomial(new_dist, num_samples=1)
generated_ids = torch.cat([generated_ids, resampled_token_id.unsqueeze(0)], dim=-1)
break

if all_accepted:

final_token_id = torch.multinomial(target_fianl_token_probs, num_samples=1)
generated_ids = torch.cat([generated_ids, final_token_id.unsqueeze(0)], dim=-1)
num_accepted += 1

accepted_input_id_len = prompt_len + num_accepted

# target_past_key_values = truncate_kv_cache(target_past_key_values, new_length=accepted_input_id_len)
# draft_past_key_values = truncate_kv_cache(draft_past_key_values, new_length=accepted_input_id_len)
target_past_key_values = truncate_kv_cache(target_past_key_values, new_length=accepted_input_id_len-1)
draft_past_key_values = truncate_kv_cache(draft_past_key_values, new_length=accepted_input_id_len-1)


print(f"Accepted {num_accepted} tokens. Total length: {accepted_input_id_len}")

1
2
3
Accepted 1 tokens. Total length: 23
Accepted 11 tokens. Total length: 35
Accepted 9 tokens. Total length: 44
Comments