CUDA memory usage optimization

Peng Xia

A note on simple ways to reduce GPU memory usage when training transformers models:

  • half-precision loading
  • mixed-precision training
  • activation (gradient) checkpointing
  • lora
  • flash attention

It does not include code for things like FSDP or DeepSpeed, because I ran these experiments in a Jupyter notebook.

PyTorch reports GPU memory in two parts:

  • allocated: memory currently in use by live tensors, e.g. model parameters, gradients, optimizer states, and other active variables.
  • reserved: memory held by PyTorch's CUDA caching allocator but not currently in use, e.g. blocks left behind after intermediate variables are freed. It may be reused later in training and is only returned to the driver by an explicit torch.cuda.empty_cache() (see the sketch below).
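
A minimal sketch (assuming a CUDA device with enough free memory) of how allocated and reserved diverge: freeing a tensor lowers allocated immediately, but the block stays reserved in the caching allocator until empty_cache() is called.

import torch

x = torch.randn(1024, 1024, 1024, device="cuda")   # ~4 GB of float32
print(torch.cuda.memory_allocated() / 1024**3)     # ~4 GB allocated
print(torch.cuda.memory_reserved() / 1024**3)      # ~4 GB reserved

del x                                              # the tensor is freed ...
print(torch.cuda.memory_allocated() / 1024**3)     # ~0 GB allocated
print(torch.cuda.memory_reserved() / 1024**3)      # ... but the block is still reserved

torch.cuda.empty_cache()                           # return cached blocks to the driver
print(torch.cuda.memory_reserved() / 1024**3)      # ~0 GB reserved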

Non-training setup code

import json, torch
from copy import deepcopy
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq, Trainer, TrainingArguments

model_name_or_path = "Qwen/Qwen2-0.5B-Instruct"

def get_memory_usage():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
        reserved = torch.cuda.memory_reserved(device) / (1024 ** 3)
        max_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
        max_reserved = torch.cuda.max_memory_reserved(device) / (1024 ** 3)

        print(f"[Current] Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
        print(f"[Peak] Max Allocated: {max_allocated:.2f} GB, Max Reserved: {max_reserved:.2f} GB")
    else:
        print("CUDA not available")


get_memory_usage()
dataset = load_dataset("trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness", split="train[:100]")
# print(json.dumps(dataset[0], indent=2, ensure_ascii=False))
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# The default template adds a system prompt to every conversation. Since we process each message separately during preprocessing, we use a template that does not add the system prompt again.
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
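
A quick illustration of the simplified template (an added check, not part of the original notebook): each message is rendered on its own, with no implicit system prompt.

print(tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}], tokenize=False))
# -> <|im_start|>user\nhello<|im_end|>\n
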
def preprocess(example):

    input_ids = []
    labels = []

    system_prompt = [{"role": "system", "content": "You are a helpful assistant."}]
    prompt = example["prompt"]
    completion = example["completion"]

    text = tokenizer.apply_chat_template(system_prompt, tokenize=False)
    encode_ids = tokenizer.encode(text)
    input_ids.extend(encode_ids)
    labels.extend([-100] * len(encode_ids))

    text = tokenizer.apply_chat_template(prompt, tokenize=False)
    encode_ids = tokenizer.encode(text)
    input_ids.extend(encode_ids)
    labels.extend([-100] * len(encode_ids))

    text = tokenizer.apply_chat_template(completion, tokenize=False)
    encode_ids = tokenizer.encode(text)
    input_ids.extend(encode_ids)
    target_encode_ids = deepcopy(encode_ids)
    target_encode_ids[:3] = [-100] * 3  # mask out the first three tokens (the '<|im_start|>' + role + '\n' header), they should not contribute to the loss
    labels.extend(target_encode_ids)

    attention_mask = [1] * len(input_ids)

    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
    )

tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding='longest',
    max_length=512,
    return_tensors="pt",
)
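
As a quick check of what the collator produces (an added sketch, not in the original notebook): DataCollatorForSeq2Seq pads labels with -100, so padded positions are ignored by the loss.

batch = data_collator([tokenized_dataset[0], tokenized_dataset[1]])
print(batch["input_ids"].shape, batch["labels"].shape)          # both padded to the longer example
print((batch["labels"] == -100).sum().item(), "label positions are masked out")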

vanilla

Below is the baseline training code, with almost no arguments specified.

model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    # bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 1.84 GB, Reserved: 1.95 GB
[Peak] Max Allocated: 1.84 GB, Max Reserved: 1.95 GB

After training:
[Current] Allocated: 5.61 GB, Reserved: 20.54 GB
[Peak] Max Allocated: 19.26 GB, Max Reserved: 20.54 GB

When no dtype is specified, the model is loaded in float32, so allocated ≈ 0.5 billion parameters × 4 bytes (float32 is four bytes) ≈ 2 GB. After training, the optimizer's first- and second-moment states plus the model itself take 3 × 0.5 billion × 4 bytes ≈ 6 GB; the gradients have already been freed by the time training finishes.
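
A quick back-of-the-envelope check of those numbers (the 0.5 B count is a rounded approximation; the real model has slightly fewer parameters plus extra buffers, and AdamW keeps two extra fp32 states per parameter):

n_params = 0.5e9
bytes_per_param_fp32 = 4

weights_gb = n_params * bytes_per_param_fp32 / 1024**3                 # ~1.9 GB for the weights alone
weights_plus_adam_gb = 3 * n_params * bytes_per_param_fp32 / 1024**3   # weights + exp_avg + exp_avg_sq, ~5.6 GB
print(f"{weights_gb:.2f} GB weights, {weights_plus_adam_gb:.2f} GB weights + optimizer states")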

half-precision

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    # bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print(trainer.args.bf16)


print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB

After training:
[Current] Allocated: 2.80 GB, Reserved: 16.16 GB
[Peak] Max Allocated: 12.05 GB, Max Reserved: 16.16 GB

After loading the model in half precision (bf16), the memory numbers show that everything is kept in bf16, roughly halving the float32 footprint.
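
A simple way to confirm this (an added check, not in the original notebook) is to look at the parameter dtypes directly:

for name, p in list(model.named_parameters())[:3]:
    print(name, p.dtype)   # expected: torch.bfloat16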

mixed precision training

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print(trainer.args.bf16)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB

After training:
[Current] Allocated: 2.80 GB, Reserved: 16.16 GB
[Peak] Max Allocated: 12.05 GB, Max Reserved: 16.16 GB

gradient checkpointing

  • gradient_checkpointing=True: a memory-saving technique for training. Activations are discarded during the forward pass and recomputed during the backward pass. Normally all intermediate activations are kept; with gradient checkpointing enabled, only a subset is retained.
  • use_cache=True: caches past_key_values in the forward pass to speed up generation (decoder caching at inference time). gradient_checkpointing and use_cache are not compatible.

The following code shows which modules are affected by checkpointing:

for name, module in model.named_modules():
    if hasattr(module, "gradient_checkpointing") or "Block" in name:
        print(name, type(module))
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model.config.use_cache = False  # explicitly disable the KV cache
model.gradient_checkpointing_enable()  # enable gradient checkpointing


training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 1.84 GB, Reserved: 1.95 GB
[Peak] Max Allocated: 1.84 GB, Max Reserved: 1.95 GB

After training:
[Current] Allocated: 5.55 GB, Reserved: 15.49 GB
[Peak] Max Allocated: 12.17 GB, Max Reserved: 15.49 GB

With gradient checkpointing enabled, activation memory drops by roughly a quarter.
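
The same effect can also be obtained by letting the Trainer enable checkpointing instead of calling gradient_checkpointing_enable() by hand; a sketch with the same hyperparameters as above:

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,
    gradient_checkpointing=True,  # Trainer calls model.gradient_checkpointing_enable() for us
    report_to=[],
)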

lora

model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

from peft import get_peft_model, LoraConfig, TaskType
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # Qwen uses q_proj/v_proj
)
model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 1.84 GB, Reserved: 1.95 GB
[Peak] Max Allocated: 1.84 GB, Max Reserved: 1.95 GB

After training:
[Current] Allocated: 1.86 GB, Reserved: 17.88 GB
[Peak] Max Allocated: 13.29 GB, Max Reserved: 17.88 GB

Allocated memory barely changes before and after training: LoRA trains very few parameters, so the corresponding optimizer states are tiny as well. Activation memory is also reduced slightly.
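
To see how few parameters are actually being trained, peft can report the count directly (model here is the get_peft_model output from above):

model.print_trainable_parameters()
# prints trainable vs. total parameter counts; with r=8 on q_proj/v_proj the
# trainable fraction is well under 1% of the 0.5B base model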

flash attention

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    # bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB

After training:
[Current] Allocated: 2.80 GB, Reserved: 15.93 GB
[Peak] Max Allocated: 11.62 GB, Max Reserved: 15.93 GB

Flash attention also shaves off a bit of activation memory (peak allocated drops from 12.05 GB in the plain bf16 run to 11.62 GB).
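
An optional sanity check (this relies on an internal transformers attribute that may change between versions):

print(model.config._attn_implementation)  # expected: "flash_attention_2"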

all together

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

model.config.use_cache = False  # explicitly disable the KV cache
model.gradient_checkpointing_enable()  # enable gradient checkpointing

from peft import get_peft_model, LoraConfig, TaskType
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # Qwen uses q_proj/v_proj
)
model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    # bf16=True,  # enable bf16 mixed precision
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()

Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB

After training:
[Current] Allocated: 0.95 GB, Reserved: 13.53 GB
[Peak] Max Allocated: 6.02 GB, Max Reserved: 13.53 GB

Combining the methods above, peak memory drops by about two thirds on the model side and about 30% on the activation side (peak allocated goes from 19.26 GB in the vanilla run to 6.02 GB).
