0%

多GPU分布式训练

多GPU训练

Python’s GIL限制

Global Interpreter Lock

是一个互斥锁,用于保护对Python对象的访问,从而防止多个线程同时执行Python字节码。,确保任何时候只有一个线程在运行,无法将多个处理器的线程一起使用(使用multiprocessing库)

multiprocessing library

使用进程池 multiprocessing.Pool(p)

P为默认的CPU核心数

俺的另一篇博客,如何使用进程池。

线程安全

Python线程共享内存,访问数据取决于调度算法

GIL保证一次只运行一个线程

GIL瓶颈

许多I/O,图像处理和NumPy数学运算在GIL之外

  1. 阻碍多线程Cpython程序充分利用多处理器优势
  2. GIL内部花费大量时间来解释CPython字节码的多线程程序

并行的方式计算Loss

DataParallel(DP)

每个部分输出保留在其GPU上,而不是将所有部分输出收集到 master节点上。 我们还需要分配我们的损失准则计算,以便能够计算和反向传播我们的损失。

Multi Machine(DDP)

Distributed training (torch.nn.DistributedDataParallel)

即使在单机设置中,DistributedDataParallel仍可以有效地替换DataParallel。

在每个节点上独立启动python脚本,通过PyTorch分布式后端同步

each training script

  • 独立的optimizer :no parameter broadcast ( DataParallel is needed)
  • 独立的Python解释器:避免在单个Python解释器中驱动多个并行执行线程而产生GIL-freeze

调整每个机器(节点)的Python训练脚本

保证每个节点能分别运行脚本,每块GPU分配单独进程

  1. 初始化分布式后端以进行同步
  2. 封装模型并准备数据以在数据的单独子集上训练每个流程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

# 每个进程运行在local_rank参数指定的1个GPU设备上

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()

# 初始化分布式后端,同步节点/GPU

torch.distributed.init_process_group(backend='nccl')

# 模型封装在分配给当前进程的GPU

device = torch.device('cuda', arg.local_rank)
model = model.to(device)
distrib_model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank],
output_device=args.local_rank)

# 数据加载限制为当前进程专属的数据集子集

sampler = DistributedSampler(dataset)

dataloader = DataLoader(dataset, sampler=sampler)
for inputs, labels in dataloader:
predictions = distrib_model(inputs.to(device)) # 前传

loss = loss_function(predictions, labels.to(device)) # 计算损失

loss.backward() # 反传

optimizer.step() # Optimizer

多节点启动

torch.distributed.launch

可在每个训练节点上产生多个分布式训练进程 GPU 0 to GPU(nproc_per_node-1)

1
2
3
4
5
6
7
8
9
10
python -m torch.distributed.launch  
--nproc_per_node=NUM_GPUS_YOU_HAVE # 启动的进程数

--nnodes=2 # 节点数

--node_rank=0 # 当前节点编号

--master_addr="192.168.1.1" # 主节点IP

--master_port=1234 # 端口号