快速入门[10]


# 模型
# deepspeed.initialize() returns a 4-tuple; only the first two items are kept
# (bound to model_engine and optimizer), the last two are discarded.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=cmd_args,
    model=model,
    model_parameters=params,
)

~~torch.distributed.init_process_group(...)~~
# DeepSpeed's replacement for the torch.distributed.init_process_group(...)
# call that is struck out on the previous line.
deepspeed.init_distributed()

# 训练
# One iteration per batch. The loss produced by the forward call feeds
# backward(), and step() follows it, so the three calls stay in this order.
for step, batch in enumerate(data_loader):
    # forward pass: calling the engine on the batch returns the loss
    loss = model_engine(batch)

    # backpropagation is driven through the engine, using that loss
    model_engine.backward(loss)

    # weight update
    model_engine.step()

# load checkpoint
# load_checkpoint() is unpacked as a 2-tuple: the first element is discarded,
# the second is kept as the client state dict, from which 'step' is read.
_, client_sd = model_engine.load_checkpoint(args.load_dir, args.ckpt_id)
step = client_sd['step']

# advance data loader to ckpt step
# NOTE(review): dataloader_to_step is not defined in this snippet; presumably a
# user-written helper that skips batches already consumed before the
# checkpoint. Verify its exact contract (in-place vs. returning a new loader).
dataloader_to_step(data_loader, step + 1)

# Resumed training loop that saves a checkpoint every `args.save_interval`
# steps.
# NOTE(review): enumerate() restarts `step` at 0 here even though the data
# loader was advanced to the checkpointed step beforehand. If the saved 'step'
# is meant to be a global step, enumerate(data_loader, start=step + 1) may be
# what is wanted; confirm against dataloader_to_step's contract.
for step, batch in enumerate(data_loader):

    # forward pass: calling the engine on the batch returns the loss
    loss = model_engine(batch)

    # backpropagation is driven through the engine, using that loss
    model_engine.backward(loss)

    # weight update
    model_engine.step()

    # Save only when `step` is a multiple of the interval. The previous bare
    # `step % args.save_interval` test was truthy on every step that is NOT a
    # multiple, i.e. the condition was inverted.
    if step % args.save_interval == 0:
        client_sd['step'] = step
        # NOTE(review): the loss value is used as the checkpoint tag. The tag
        # presumably has to be identical on all ranks; confirm, or switch to a
        # step-based tag.
        ckpt_id = loss.item()
        # The engine's keyword for user-supplied state is `client_state`;
        # passing `client_sd=` is rejected as an unexpected keyword argument.
        model_engine.save_checkpoint(args.save_dir, ckpt_id, client_state=client_sd)

ZeRO 配置[1]

多节点 [10]

worker-1 slots=4
worker-2 slots=4
# Launch using the hosts listed in myhostfile (presumably the worker-1 /
# worker-2 "slots=4" listing shown just above).
# A single trailing backslash continues the command; a doubled one is a literal
# backslash and would end the command at the newline.
deepspeed --hostfile=myhostfile <client_entry.py> <client args> \
  --deepspeed --deepspeed_config ds_config.json
  

# Limit the job to two nodes via --num_nodes.
# A single trailing backslash continues the command; a doubled one is a literal
# backslash and would end the command at the newline.
deepspeed --num_nodes=2 \
  <client_entry.py> <client args> \
  --deepspeed --deepspeed_config ds_config.json
  

# Exclude specific resources: GPU 0 on worker-2 and GPUs 0 and 1 on worker-3
# ('@' separates the per-node entries).
# A single trailing backslash continues the command; a doubled one is a literal
# backslash and would end the command at the newline.
deepspeed --exclude="worker-2:0@worker-3:0,1" \
  <client_entry.py> <client args> \
  --deepspeed --deepspeed_config ds_config.json
	
# Restrict the run to the listed resources: GPUs 0 and 1 on worker-2.
# A single trailing backslash continues the command; a doubled one is a literal
# backslash and would end the command at the newline.
deepspeed --include="worker-2:0,1" \
  <client_entry.py> <client args> \
  --deepspeed --deepspeed_config ds_config.json

多节点环境变量