Notes on some simple ways to reduce GPU memory usage when training transformers models:
- half-precision loading
- mixed-precision training
- activation (gradient) checkpointing
- LoRA
- flash attention

Code for FSDP, DeepSpeed, and similar approaches is not covered, because I ran these experiments in a Jupyter notebook.
GPU memory reported by PyTorch falls into two parts:
- allocated: memory currently occupied by live tensors and other active objects, such as model parameters, gradients, optimizer states, and live activations
- reserved: memory held by PyTorch's CUDA caching allocator but not currently in use, such as space left behind after intermediate tensors are freed; later steps can reuse it, and it is only returned to the driver by an explicit torch.cuda.empty_cache() (see the sketch below)
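As a quick illustration of the difference, here is a standalone sketch (not part of the experiments below): deleting a tensor lowers allocated immediately, while reserved only drops after empty_cache().

```python
import torch

x = torch.randn(1024, 1024, 256, device="cuda")   # ~1 GB of float32
print(torch.cuda.memory_allocated() / 1024**3)     # allocated ~ 1 GB
del x                                              # free the tensor
print(torch.cuda.memory_allocated() / 1024**3)     # allocated drops back towards 0
print(torch.cuda.memory_reserved() / 1024**3)      # reserved stays ~1 GB in the caching allocator
torch.cuda.empty_cache()                           # hand cached blocks back to the driver
print(torch.cuda.memory_reserved() / 1024**3)      # now reserved drops too
```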
setup code (not training-related)
```python
import json, torch
from copy import deepcopy
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq, Trainer, TrainingArguments

model_name_or_path = "Qwen/Qwen2-0.5B-Instruct"


def get_memory_usage():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
        reserved = torch.cuda.memory_reserved(device) / (1024 ** 3)
        max_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
        max_reserved = torch.cuda.max_memory_reserved(device) / (1024 ** 3)

        print(f"[Current] Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
        print(f"[Peak] Max Allocated: {max_allocated:.2f} GB, Max Reserved: {max_reserved:.2f} GB")
    else:
        print("CUDA not available")


get_memory_usage()
```
```python
dataset = load_dataset("trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness", split="train[:100]")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
```
```python
def preprocess(example):
    input_ids = []
    labels = []

    system_prompt = [{"role": "system", "content": "You are a helpful assistant."}]
    prompt = example["prompt"]
    completion = example["completion"]

    text = tokenizer.apply_chat_template(system_prompt, tokenize=False)
    encode_ids = tokenizer.encode(text)
    input_ids.extend(encode_ids)
    labels.extend([-100] * len(encode_ids))

    text = tokenizer.apply_chat_template(prompt, tokenize=False)
    encode_ids = tokenizer.encode(text)
    input_ids.extend(encode_ids)
    labels.extend([-100] * len(encode_ids))

    text = tokenizer.apply_chat_template(completion, tokenize=False)
    encode_ids = tokenizer.encode(text)
    input_ids.extend(encode_ids)
    target_encode_ids = deepcopy(encode_ids)
    target_encode_ids[:3] = [-100] * 3
    labels.extend(target_encode_ids)

    attention_mask = [1] * len(input_ids)

    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
    )


tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding='longest',
    max_length=512,
    return_tensors="pt",
)
```
vanilla
Below is the plain baseline training code, with almost no arguments specified.
```python
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 1.84 GB, Reserved: 1.95 GB
[Peak] Max Allocated: 1.84 GB, Max Reserved: 1.95 GB
After training:
[Current] Allocated: 5.61 GB, Reserved: 20.54 GB
[Peak] Max Allocated: 19.26 GB, Max Reserved: 20.54 GB
When no dtype is specified, the model is loaded in float32, so allocated = 0.5 billion params * 4 bytes (float32 is 4 bytes per value) ≈ 2 GB. After training, the optimizer's first- and second-moment states plus the model itself take 3 * 0.5 billion * 4 bytes ≈ 6 GB; the gradients have already been cleared by the time training finishes, so they no longer show up in the allocated figure.
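The same arithmetic can be reproduced from the loaded model itself. A rough sanity check (num_parameters() is a standard transformers method; the exact count for Qwen2-0.5B-Instruct is a bit under 0.5B):

```python
# Back-of-the-envelope check of the numbers above for the float32 run.
n_params = model.num_parameters()                # roughly 0.5B for Qwen2-0.5B-Instruct
weights_gb = n_params * 4 / 1024**3              # float32 weights: ~2 GB
adam_states_gb = n_params * 4 * 2 / 1024**3      # AdamW first + second moments: ~4 GB
print(f"weights ~ {weights_gb:.2f} GB, optimizer states ~ {adam_states_gb:.2f} GB")
```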
half-precision
```python
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print(trainer.args.bf16)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB
After training:
[Current] Allocated: 2.80 GB, Reserved: 16.16 GB
[Peak] Max Allocated: 12.05 GB, Max Reserved: 16.16 GB
With the model loaded in bf16 half precision, the memory figures show that everything is kept in bf16, roughly halving the float32 numbers above.
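To confirm that the weights really are stored in bf16 after loading this way, the parameter dtypes can be inspected directly (a quick check):

```python
from collections import Counter

# Count parameters by dtype; with torch_dtype=torch.bfloat16 every entry should be torch.bfloat16.
print(Counter(p.dtype for p in model.parameters()))
```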
mixed precision training
```python
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print(trainer.args.bf16)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB
After training:
[Current] Allocated: 2.80 GB, Reserved: 16.16 GB
[Peak] Max Allocated: 12.05 GB, Max Reserved: 16.16 GB
gradient checkpointing
gradient_checkpointing=True: a technique for saving memory during training; it discards activations and recomputes them during the backward pass. Normally all intermediate activations are kept; with gradient checkpointing enabled, only a subset of them is retained.

use_cache=True: caches past_key_values during the model's forward pass to speed up generation (e.g. decoder caching at inference time). gradient_checkpointing and use_cache are not compatible.
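Under the hood this is the same mechanism exposed by torch.utils.checkpoint: the wrapped forward pass runs without storing its intermediate activations and is re-run during backward. A minimal standalone sketch of that behavior (illustrative only, not the Trainer code used below):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU())
x = torch.randn(4, 512, requires_grad=True)

# Activations inside `layer` are not kept; they are recomputed when backward() runs.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```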
The following code shows which layers are checkpointed:
```python
for name, module in model.named_modules():
    if hasattr(module, "gradient_checkpointing") or "Block" in name:
        print(name, type(module))
```
```python
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model.config.use_cache = False
model.gradient_checkpointing_enable()

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 1.84 GB, Reserved: 1.95 GB
[Peak] Max Allocated: 1.84 GB, Max Reserved: 1.95 GB
After training:
[Current] Allocated: 5.55 GB, Reserved: 15.49 GB
[Peak] Max Allocated: 12.17 GB, Max Reserved: 15.49 GB
With gradient checkpointing enabled, activation memory usage drops by nearly a quarter.
lora
```python
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

from peft import get_peft_model, LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    bf16=True,
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 1.84 GB, Reserved: 1.95 GB
[Peak] Max Allocated: 1.84 GB, Max Reserved: 1.95 GB
After training:
[Current] Allocated: 1.86 GB, Reserved: 17.88 GB
[Peak] Max Allocated: 13.29 GB, Max Reserved: 17.88 GB
Allocated memory barely changes before and after training: LoRA trains very few parameters, so the corresponding optimizer states are small as well. It also reduces activation memory slightly.
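How few parameters LoRA actually trains can be checked directly on the peft-wrapped model; print_trainable_parameters() is a standard peft helper:

```python
# Only the injected LoRA adapter weights require grad; the base model stays frozen.
model.print_trainable_parameters()

# Equivalent manual count:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.3f}%)")
```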
flash attention
```python
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB
After training:
[Current] Allocated: 2.80 GB, Reserved: 15.93 GB
[Peak] Max Allocated: 11.62 GB, Max Reserved: 15.93 GB
Flash Attention also shaves a bit off the activation memory.
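One way to confirm which attention implementation was actually instantiated is to look at the attention modules' class names (a quick check; the exact class name depends on the transformers version):

```python
# With attn_implementation="flash_attention_2", the attention class name typically
# contains "FlashAttention" (naming varies across transformers versions).
attn_classes = {type(m).__name__ for n, m in model.named_modules() if n.endswith("self_attn")}
print(attn_classes)
```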
all together
```python
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False
model.gradient_checkpointing_enable()

from peft import get_peft_model, LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir="./train_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=6,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    deepspeed=None,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Before training:")
get_memory_usage()

trainer.train()

print("After training:")
get_memory_usage()
```
Before training:
[Current] Allocated: 0.93 GB, Reserved: 0.97 GB
[Peak] Max Allocated: 0.93 GB, Max Reserved: 0.97 GB
After training:
[Current] Allocated: 0.95 GB, Reserved: 13.53 GB
[Peak] Max Allocated: 6.02 GB, Max Reserved: 13.53 GB
Combining all of the above, peak memory drops by about two thirds on the model side and by about 30% for activations.