#!/bin/bash
# Number of Nodes
#SBATCH --nodes=3
# Number of tasks. 3 (1 per node)
#SBATCH --ntasks=3
# Number of GPU per node
#SBATCH --gres=gpu:4
#SBATCH --gpus-per-node=4
# 16 CPUs per node
#SBATCH --cpus-per-gpu=4
# 16Go per nodes (4Go per GPU)
#SBATCH --mem=16G
# we need all nodes to be ready at the same time
#SBATCH --wait-all-nodes=1

# Total resources:
#   CPU: 16 * 3 = 48
#   RAM: 16 * 3 = 48 Go
#   GPU: 4 * 3 = 12

# Setup our rendez-vous point: the first node (the one running this batch
# script) hosts the c10d rendezvous that the other nodes connect to.
RDV_ADDR=$(hostname)
WORLD_SIZE=$SLURM_JOB_NUM_NODES

# -----
# srun launches one torchrun per node (--ntasks=3); torchrun then spawns
# one worker process per GPU on its node.
# NOTE(review): --rdzv_endpoint usually needs an explicit port, e.g.
# "$RDV_ADDR:29400" — confirm against your torchrun version's default.
srun --label torchrun \
  --nproc_per_node="$SLURM_GPUS_PER_NODE" \
  --nnodes="$WORLD_SIZE" \
  --rdzv_id="$SLURM_JOB_ID" \
  --rdzv_backend=c10d \
  --rdzv_endpoint="$RDV_ADDR" \
  training_script.py
Below is a PyTorch script outline showing what a multi-node trainer could look like.
import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.distributed.elastic.utils.data import ElasticDistributedSampler


class Trainer:
    """Outline of a multi-node, multi-GPU trainer driven by torchrun.

    Each worker process (one per GPU) builds its own Trainer; ranks are
    read from the RANK / LOCAL_RANK environment variables that torchrun
    sets for every spawned process.
    """

    def __init__(self):
        # Ranks are unknown until initialize() reads them from the environment.
        self.local_rank = None
        self.global_rank = None
        self.chk_path = ...
        self.model = ...

    @property
    def device_id(self):
        # One process per GPU: the local rank doubles as the CUDA device index.
        return self.local_rank

    def load_checkpoint(self, path=None):
        """Load training state from ``path`` (or from the stored ``chk_path``
        when called with no argument, as ``sync_weights`` does)."""
        if path is not None:
            self.chk_path = path
        # ...

    def should_checkpoint(self):
        # Note: only one worker — the global leader — saves its weights.
        return self.global_rank == 0 and self.local_rank == 0

    def save_checkpoint(self):
        if self.chk_path is None:
            return
        # Save your states here
        # Note: you should save the weights of self.model, not the DDP wrapper.
        # ...

    def initialize(self):
        """Read ranks from the torchrun environment and join the process group."""
        self.global_rank = int(os.environ.get("RANK", -1))
        self.local_rank = int(os.environ.get("LOCAL_RANK", -1))

        assert self.global_rank >= 0, 'Global rank should be set (Only Rank 0 can save checkpoints)'
        assert self.local_rank >= 0, 'Local rank should be set'

        # Placeholder: pick ONE backend — "nccl" for GPU, "gloo" for CPU.
        dist.init_process_group(backend="gloo|nccl")

    def sync_weights(self, resuming=False):
        """Make every worker start from identical weights."""
        if resuming:
            # In the case of resuming, all workers need to load the same checkpoint.
            self.load_checkpoint()
            # Wait for everybody to finish loading the checkpoint.
            dist.barrier()
            return

        # Make sure all workers have the same initial weights:
        # the leader saves its weights...
        if self.should_checkpoint():
            self.save_checkpoint()

        # ...all workers wait for the leader to finish...
        dist.barrier()

        # ...all followers load the leader's weights...
        if not self.should_checkpoint():
            self.load_checkpoint()

        # ...and the leader waits for the followers to load them.
        dist.barrier()

    def dataloader(self, dataset, batch_size):
        """Build a loader whose sampler shards the dataset across workers."""
        train_sampler = ElasticDistributedSampler(dataset)
        train_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=4,
            pin_memory=True,
            sampler=train_sampler,
        )
        return train_loader

    def train_step(self, batch):
        # Your batch processing step here.
        # NOTE(review): the forward pass should go through the DDP wrapper
        # built in train() so gradients are synchronized — confirm when
        # filling in this outline.
        pass

    def train(self, dataset, batch_size):
        self.sync_weights()

        ddp_model = torch.nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.device_id],
            output_device=self.device_id,
        )

        loader = self.dataloader(dataset, batch_size)

        for epoch in range(100):
            for batch in iter(loader):
                self.train_step(batch)

            # Only the leader writes the end-of-epoch checkpoint.
            if self.should_checkpoint():
                self.save_checkpoint()


def main():
    trainer = Trainer()
    # Placeholders: `path`, `dataset` and `batch_size` must be supplied by
    # the caller / CLI in a real script.
    trainer.load_checkpoint(path)
    trainer.initialize()
    trainer.train(dataset, batch_size)
Note
To bypass the Python GIL (Global Interpreter Lock), PyTorch spawns one process for each GPU.
In the example above this means at least 12 processes are spawned — at least 4 on each of the 3 nodes.