<aside> 💡
parameter sharding
</aside>
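In DistributedDataParallel (DDP) training, each process (worker) owns a replica of the model and processes its own batch of data. After the backward pass, it uses all-reduce to sum the gradients across workers. PyTorch DDP then divides by the number of workers, so every replica applies the same averaged gradient. As a result, the model weights and optimizer states in DDP are replicated on every worker. FSDP (Fully Sharded Data Parallel) is a type of data parallelism that instead shards model parameters, optimizer states, and gradients across DDP ranks. Each rank stores only its shard and materializes a layer's full weights only while that layer is being computed.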
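Below is a minimal single-process sketch of the difference. It assumes plain Python lists stand in for tensors and that each outer list entry represents one worker. `all_reduce_mean` mimics DDP's gradient all-reduce (sum, then divide by the number of workers), and `shard` mimics how FSDP might split a flat parameter vector into one equal, zero-padded chunk per rank. Both helper names are hypothetical; this is not the PyTorch API.

```python
from typing import List

Vector = List[float]


def all_reduce_mean(per_rank_grads: List[Vector]) -> List[Vector]:
    """Simulated DDP all-reduce: every rank ends up with the element-wise mean gradient."""
    world_size = len(per_rank_grads)
    mean = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    # Each rank receives its own full copy of the reduced gradient (replicated state).
    return [list(mean) for _ in range(world_size)]


def shard(flat: Vector, world_size: int) -> List[Vector]:
    """Simulated FSDP sharding: split into world_size equal chunks, zero-padding the tail."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


if __name__ == "__main__":
    grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]  # two workers, three parameters
    print(all_reduce_mean(grads))  # [[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]]

    params = [float(i) for i in range(10)]
    # DDP: every rank stores all 10 values. FSDP with 4 ranks: 3 values each (2 are padding).
    print([len(s) for s in shard(params, 4)])  # [3, 3, 3, 3]
```

With DDP, per-rank parameter memory stays constant as workers are added. With sharding, it shrinks roughly as 1 / world_size, plus a little padding.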



FSDP forward pass:

    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i
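
The forward loop above can be simulated in a single process. This is a sketch, not PyTorch's FSDP implementation, and it rests on several assumptions that are not in the original. Each "layer" is a toy element-wise multiply `y = w * x`, and weights are plain Python lists. `sharded_layers[i][r]` is rank `r`'s shard of layer `i`. The hypothetical `all_gather` helper simply concatenates the shards and strips the padding. The full weights exist only inside one loop iteration, which is where FSDP's memory saving comes from.

```python
from typing import List

Vector = List[float]


def shard(flat: Vector, world_size: int) -> List[Vector]:
    """Split a flat weight vector into world_size equal, zero-padded chunks."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[Vector], numel: int) -> Vector:
    """Simulated all-gather: concatenate every rank's shard, then drop the padding."""
    return [v for s in shards for v in s][:numel]


def fsdp_forward(x: Vector, sharded_layers: List[List[Vector]], numels: List[int]) -> Vector:
    for layer_shards, numel in zip(sharded_layers, numels):
        full_w = all_gather(layer_shards, numel)  # all-gather full weights for layer_i
        x = [w * xi for w, xi in zip(full_w, x)]  # forward pass for layer_i (toy layer)
        del full_w                                # discard full weights for layer_i
    return x


if __name__ == "__main__":
    world_size = 2
    layers = [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 2.0, 2.0, 2.0]]
    sharded = [shard(w, world_size) for w in layers]  # what the ranks actually store
    numels = [len(w) for w in layers]
    print(fsdp_forward([1.0] * 5, sharded, numels))  # [2.0, 4.0, 6.0, 8.0, 10.0]
```

The output matches running the unsharded layers directly. Sharding changes where the weights live, not what the layer computes.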

FSDP backward pass (layers are visited in reverse order):

    for layer_i in reversed(layers):
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i
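
The backward loop can be sketched the same way, under the same toy assumptions: element-wise layers `y = w * x`, plain lists, a loss equal to the sum of the outputs, and simulated collectives with hypothetical helper names. The code visits layers in reverse order, as autograd does. Each rank computes full-size gradients for its own batch. `reduce_scatter_mean` then averages them across ranks (sum, then divide by the number of ranks, mirroring DDP's averaging) and hands rank `r` only chunk `r`. The result equals sharding a DDP-style all-reduce output, but no rank ever holds the full averaged gradient.

```python
from typing import List

Vector = List[float]


def shard(flat: Vector, world_size: int) -> List[Vector]:
    """Split a flat vector into world_size equal, zero-padded chunks."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[Vector], numel: int) -> Vector:
    """Simulated all-gather: concatenate all shards, then drop the padding."""
    return [v for s in shards for v in s][:numel]


def reduce_scatter_mean(per_rank_grads: List[Vector], world_size: int) -> List[Vector]:
    """Simulated reduce-scatter: average across ranks, rank r keeps only chunk r."""
    mean = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return shard(mean, world_size)


def fsdp_backward(inputs: List[Vector], sharded_layers: List[List[Vector]],
                  numel: int) -> List[List[Vector]]:
    """Return grad_shards[i][r]: rank r's shard of the averaged gradient of layer i."""
    world_size = len(inputs)  # one input vector (batch) per rank
    # Forward pass, saving each layer's input on every rank for use in backward.
    saved = []
    for x in inputs:
        acts = []
        for layer_shards in sharded_layers:
            w = all_gather(layer_shards, numel)
            acts.append(x)
            x = [wi * xi for wi, xi in zip(w, x)]
        saved.append(acts)

    # loss = sum(outputs), so the upstream gradient starts as all ones on every rank.
    upstream = [[1.0] * numel for _ in range(world_size)]
    grad_shards: List[List[Vector]] = [[] for _ in sharded_layers]
    for i in reversed(range(len(sharded_layers))):    # backward visits layers in reverse
        full_w = all_gather(sharded_layers[i], numel)  # all-gather full weights for layer_i
        local_grads = []
        for r in range(world_size):                    # backward pass for layer_i on each rank
            x_in = saved[r][i]
            local_grads.append([g * xi for g, xi in zip(upstream[r], x_in)])  # dL/dw
            upstream[r] = [g * w for g, w in zip(upstream[r], full_w)]        # dL/dx
        del full_w                                     # discard full weights for layer_i
        grad_shards[i] = reduce_scatter_mean(local_grads, world_size)  # reduce-scatter grads
    return grad_shards


if __name__ == "__main__":
    w1, w2 = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
    sharded = [shard(w1, 2), shard(w2, 2)]
    grads = fsdp_backward([[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], sharded, numel=3)
    print(grads[0])  # [[8.0, 10.0], [12.0, 0.0]]  (first layer: mean of w2 * x)
    print(grads[1])  # [[2.0, 4.0], [6.0, 0.0]]    (second layer: mean of w1 * x)
```

Because each rank ends up with only the gradient shard that matches its parameter shard, it can run the optimizer step on that shard alone. That is why optimizer states can be sharded as well.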