FSDP - Fully Sharded Data Parallel

Peng Xia

FSDP 是 DDP 的一种改进,它晚于 deepspeed 出现,借鉴了 deepspeed 的一个概念,sharding,即参数拆分。但是没有借鉴全,张量并行没有实现,但理解起来不难,结合下 NCCL 通信就可以知道他干什么和怎么实现的。记住,FSDP的两个关键点,sharding 和 data parallel。

sharding

在深度学习中,shard(分片)指的是把一个整体的数据结构或模型参数拆成多个部分,分别分布到不同设备或进程中处理或存储。通俗点,Sharding 就是“切块+分发”。

Sharding 的思想就是万物都可以分裂开,只要用的时候再复原,而且用完继续分裂开。这里的分裂不是常规意义上的分裂,更类似与分布式存储,每个节点都存储部分内容。FSDP 的前任是 DDP,DDP 的前提条件是每个 gpu 上都有一个模型副本,但随着模型不断变大,模型很难在单张 gpu 上部署,而且所有 gpu 都保存了大量重复的参数,也不够有效。

模型并行是一个解决方法:

  1. 将整个模型按层划分为多个连续的阶段(stage),每个阶段由一个设备负责计算。
  2. 在每个训练迭代开始时,第一个设备获取一个 batch 的输入数据,并执行前向计算。
  3. 第一个设备将计算出的中间激活值(activations)传递给第二个阶段的设备。
  4. 第二个设备接收到激活值后,基于这个输入继续执行前向计算,并将结果传递给下一个阶段,如此类推。
  5. 直到最后一个阶段完成前向计算,得到最终的输出。
  6. 基于输出,计算损失函数,并执行反向传播。
  7. 每个阶段在反向传播时,计算本阶段所需的梯度,并将上游梯度传递给前一个阶段。
  8. 所有阶段计算完成后,将各自的梯度汇总,更新对应的模型参数。
  9. 将更新后的模型参数分发到对应的设备,准备进行下一个 batch 的训练迭代。

但他还是以 layer 层面切分模型的,现在如果有超大的 layer,必然存在放不了一个超大的 layer 或者剩下的空间不足以放下一个 layer。

虽然模型并行和 sharding 都是只在多 gpu 上保留一份模型参数,但 sharding 直接分裂开 tensor,粒度更细,使得其更灵活。

Shard parameters, gradients, and optimizer states across all data-parallel processes

  • P: Parameters
  • G: Gradients
  • OS: Optimizer states

这是 deepspeed 介绍他们三种模式的图片

随着我们从 Optimizer 到 Gradients 到 Parameters 不断 sharding,每个 gpu 上的显存占用直线降低,而通讯压力仅有极小的成本上升。

sharding 的重要基础就是 NCCL 的通信机制,例如以下的通信方式:

  • all-gather,
  • all-reduce,
  • broadcast,
  • reduce,
  • reduce-scatter
  • as well as point-to-point send and receive

The NVIDIA Collective Communication Library (NCCL) implements multi-GPU and multi-node communication primitives optimized for NVIDIA GPUs and Networking.

在 DDP 中我们就用了 all-reduce 的概念,即每个 gpu 上都计算了分配给它的 mini batch 的梯度,为了让所有卡上都维持相同的模型,or梯度更新都一致,我们用 all-reduce 让所有 gpu 上的梯度都获得了所有 gpu 上的梯度。FSDP 在此基础上,让 FWD 和 BWD 过程都利用 all-reduce 实现了参数随用随 gather,用完即丢的效果。

sharding 示例

这里以一个 Linear 层为例,现在共有 4 个节点,我们把 linear 层平均的分散到每个 gpu 上,每个 gpu 只有一部分,不足以直接计算,现在 FWD 和 BWD 时,每个节点以 all-gather 的形式将所有节点上的 shard 收集为完整的 linear 层,这样每个节点都可以基于它自己的数据计算出输出,现在这个输出可以转到下一个层的计算,也是同样的 all-gather,而已经用的层,后续不再需要,因此直接释放,还是只保留原本的 shard。

FSDP 并不会对数据进行切分,因为它还是属于 data parallel 的范畴,即每个节点上运行部分数据,一个节点上的数据就交给这个节点计算,因此数据不会切分。

FSDP 本质上仍是一种数据并行(DP)策略,目标是解决单卡放不下完整模型的问题,同时保留传统 DDP 的易用性与训练逻辑。
不同于模型并行(MP)或流水线并行(PP)那种“切分计算”的思想,FSDP 并不切算子,而是切分参数存储(shard weight storage),从而在所有 GPU 上共同维护一个被 shard 化的模型。

!!! 注意,FSDP 不涉及张量并行或者模型并行的概念,它只用 sharding 来存储参数,用 gather 来随时收集分散的参数,此外和正常计算一样,计算方式上没有任何改变。

toy model

以下是一个 手动 sharding 和 gather shard 的 toy model。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp


def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

def cleanup():
dist.destroy_process_group()


def all_gather_tensor(tensor, world_size):
gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
dist.all_gather(gathered, tensor)
full_tensor = torch.cat(gathered, dim=0)
return full_tensor


class LinearShard(nn.Module):
def __init__(self, in_features, out_features, rank, world_size):
super().__init__()
self.rank = rank
self.world_size = world_size

# 手动分 shard:每张卡只保存 out_features 的一部分
self.out_per_gpu = out_features // world_size
self.weight_shard = nn.Parameter(
torch.randn(self.out_per_gpu, in_features, device=f'cuda:{rank}')
)
self.bias_shard = nn.Parameter(
torch.randn(self.out_per_gpu, device=f'cuda:{rank}')
)

def forward(self, x):
# all_gather 所有 rank 的 weight 和 bias
full_weight = all_gather_tensor(self.weight_shard, self.world_size)
full_bias = all_gather_tensor(self.bias_shard, self.world_size)
print(f"Rank {self.rank} gathered weight shape: {full_weight.shape}, device: {full_weight.device}")
print(f"Rank {self.rank} gathered bias shape: {full_bias.shape}, device: {full_bias.device}")

# 所有卡执行完整 linear(广播来的 full_weight)
return torch.nn.functional.linear(x, full_weight, full_bias)

def demo(rank, world_size):
setup(rank, world_size)

# 模拟 Linear(4, 8),8维在两卡间 shard,每卡负责 4 个输出
model = LinearShard(in_features=4, out_features=8, rank=rank, world_size=world_size)
model = DDP(model, device_ids=[rank])

# forward
input = torch.randn(2, 4, device=f'cuda:{rank}')
output = model(input)

print("Output shape:", output.shape)

cleanup()

if __name__ == "__main__":
mp.spawn(demo, args=(2,), nprocs=2)

1
2
3
4
5
6
7
# CUDA_VISIBLE_DEVICES=1,2 python demo.py
Rank 1 gathered weight shape: torch.Size([8, 4]), device: cuda:1
Rank 1 gathered bias shape: torch.Size([8]), device: cuda:1
Rank 0 gathered weight shape: torch.Size([8, 4]), device: cuda:0
Rank 0 gathered bias shape: torch.Size([8]), device: cuda:0
Output shape: torch.Size([2, 8])
Output shape: torch.Size([2, 8])

FSDP

以图为例,重新分散模型:

  • unit 0: [layer 0, layer 3]
  • unit 1: [layer 1, layer 2]
  • unit 2: [layer 4, layer 5]

这里的 unit 即节点, sharding 本身没有任何顺序或者既定的比例,随便分散即可。每个节点执行 FWD 和 BWD 时都是先 gather 全部参数,计算,在释放非自己的 shard, BWD 同理,最后同步 gradient 。

再从计算和通信上看,在进行完 AG0 (all-gather layer 0)时可以同时 FWD0 和 AG1,此时通信和计算就会overlap,但不影响双方。

在 BWD 时也是 all-gather 完了进行 BWD 计算,这里直接进行了 reduce-scatter,因为 FSDP 还是类似 DDP 一样,要维持相同的模型或者梯度更新,这一步直接当前梯度计算完,直接同步。

FSDP 和 DDP 的区别还是很明显的,DDP 要在每个节点上保留一个完整的模型副本,模型比较大时就非常吃亏,因为大部分显存都被模型占着。而 FSDP 让所有节点共同存储一个模型,每个模型平摊代价,使得每个节点只需要承担 sub-model 级别的代价,使得可以跑更大的 batch 。

1
2
3
4
5
6
7
8
9
10
11
12
# 大概用 FSDP 包装下模型就可以了,当然也要设置类似 DDP 的参数

import torch
from torch.distributed._fsdp import FullyShardedDataParallel as FSDP

torch.cuda.set_device(device_id)

model = Net()
sharded_model = FSDP(model)
optim = torch.optim.Adam(sharded_model.parameters(), lr=0.0001)
sharded_model(input).sum().backward()
optim.step()

MNIST FSDP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)


# Distributed training setup

def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
dist.destroy_process_group()


class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):

x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output


def train(model, rank, train_loader, optimizer, epoch, sampler=None):
model.train()
ddp_loss = torch.zeros(2).to(rank)
if sampler:
sampler.set_epoch(epoch)
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target, reduction='sum')
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)

dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))


def test(model, rank, test_loader):
model.eval()
correct = 0
ddp_loss = torch.zeros(3).to(rank)
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
ddp_loss[2] += len(data)

dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

if rank == 0:
test_loss = ddp_loss[0] / ddp_loss[2]
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
100. * ddp_loss[1] / ddp_loss[2]))


def fsdp_main(rank, world_size, args):
setup(rank, world_size)

transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

dataset1 = datasets.MNIST('./mnist', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('./mnist', train=False,
transform=transform)

sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

# !!! 只有参数数量超过某个阈值的模块,才会被 FSDP 自动包裹(即切分)
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=100
)
torch.cuda.set_device(rank)


model = Net().to(rank)

model = FSDP(model)

optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

for epoch in range(1, args.epochs + 1):
train(model, rank, train_loader, optimizer, epoch, sampler=sampler1)
test(model, rank, test_loader)
scheduler.step()

if args.save_model:
# use a barrier to make sure training is done on all ranks
dist.barrier()
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")

cleanup()


if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()

torch.manual_seed(args.seed)

WORLD_SIZE = torch.cuda.device_count()
mp.spawn(fsdp_main,
args=(WORLD_SIZE, args),
nprocs=WORLD_SIZE,
join=True)

参考

Comments
On this page
FSDP - Fully Sharded Data Parallel