<aside> 💡

parameter sharding

</aside>

How FSDP works [2]

In DistributedDataParallel (DDP) training, each process (worker) owns a full replica of the model and processes its own batch of data; it then uses all-reduce to sum the gradients across workers. In DDP, the model weights and optimizer states are replicated on every worker. FSDP is a type of data parallelism that instead shards the model parameters, optimizer states, and gradients across the DDP ranks.
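
To make the replication-versus-sharding difference concrete, here is a rough back-of-the-envelope sketch of the persistent per-rank memory in each scheme. The function name `per_rank_state_bytes` and the byte sizes are assumptions for illustration (fp32 weights and gradients, plus an Adam-style optimizer keeping two fp32 values per parameter); they do not come from the cited sources. The estimate ignores activations and the full weights that FSDP gathers temporarily for the layer currently being computed.

```python
def per_rank_state_bytes(
    num_params: int,
    world_size: int,
    sharded: bool,
    bytes_weight: int = 4,   # assumption: fp32 weights
    bytes_grad: int = 4,     # assumption: fp32 gradients
    bytes_optim: int = 8,    # assumption: two fp32 optimizer values per parameter
) -> float:
    """Estimate persistent bytes held per rank for weights, gradients, optimizer state.

    sharded=False models DDP: every rank keeps a full copy of everything.
    sharded=True models FSDP: each rank keeps only 1/world_size of each.
    """
    total = num_params * (bytes_weight + bytes_grad + bytes_optim)
    return total / world_size if sharded else float(total)


if __name__ == "__main__":
    n = 1_000_000_000  # hypothetical 1B-parameter model
    for sharded in (False, True):
        gb = per_rank_state_bytes(n, world_size=8, sharded=sharded) / 1e9
        print("FSDP" if sharded else "DDP ", f"{gb:.1f} GB per rank")
```

Under these assumptions the sharded footprint shrinks linearly with the number of ranks, which is the main reason to pay for the extra all-gather traffic.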


fsdp_workflow.png

How FSDP works[4]

aorxyzay.bmp

kmt80h00.bmp

In pseudo-code:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i

FSDP backward pass:
    for layer_i in reversed(layers):  # backward visits layers last-to-first
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i
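
The two loops above can be simulated in a single process to see what each rank ends up holding. The following is a minimal sketch, not the PyTorch implementation: the "ranks" are just list indices, the helpers `shard`, `all_gather`, `reduce_scatter`, and `fsdp_step` are hypothetical names, and the toy model is a stack of elementwise-scaling layers with loss equal to the sum of the final outputs. Gradients are summed across ranks, matching the description of all-reduce earlier in these notes.

```python
from typing import List

Vector = List[float]


def shard(full: Vector, world_size: int) -> List[Vector]:
    """Split a flat vector into equal contiguous shards, one per rank."""
    assert len(full) % world_size == 0, "toy example: length must divide evenly"
    n = len(full) // world_size
    return [full[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards: List[Vector]) -> Vector:
    """Simulated all-gather: rebuild the full vector from every rank's shard."""
    return [v for s in shards for v in s]


def reduce_scatter(per_rank: List[Vector], world_size: int) -> List[Vector]:
    """Simulated reduce-scatter: sum across ranks, then rank r keeps shard r."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return shard(total, world_size)


def fsdp_step(layer_shards: List[List[Vector]], inputs: List[Vector]) -> List[List[Vector]]:
    """Run one forward and backward pass; return gradient shards per layer, per rank."""
    world_size = len(inputs)
    acts = [[x] for x in inputs]  # acts[r][i] is the input to layer i on rank r

    # Forward: gather full weights, compute, discard.
    for shards in layer_shards:
        w = all_gather(shards)
        for r in range(world_size):
            acts[r].append([h * wi for h, wi in zip(acts[r][-1], w)])
        del w  # full weights are not kept between layers

    # Backward: loss = sum(final output), so the incoming gradient is all ones.
    grads_out = [[1.0] * len(inputs[0]) for _ in range(world_size)]
    grad_shards: List[List[Vector]] = [[] for _ in layer_shards]
    for i in reversed(range(len(layer_shards))):
        w = all_gather(layer_shards[i])
        local_grads = []
        for r in range(world_size):
            # d(loss)/d(w) for y = h_in * w is grad_out * h_in.
            local_grads.append([g * h for g, h in zip(grads_out[r], acts[r][i])])
            # Propagate the gradient to the previous layer's output.
            grads_out[r] = [g * wi for g, wi in zip(grads_out[r], w)]
        del w
        grad_shards[i] = reduce_scatter(local_grads, world_size)
    return grad_shards


if __name__ == "__main__":
    layers = [shard([2.0, 3.0], 2), shard([4.0, 5.0], 2)]
    print(fsdp_step(layers, inputs=[[1.0, 1.0], [2.0, 0.0]]))
```

Each rank finishes with only its own slice of the summed gradient for every layer, which is exactly the slice it needs to update the parameter shard it owns.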

FSDP[3]