欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch 分布式訓(xùn)練的實(shí)現(xiàn)

 更新時(shí)間:2025年05月15日 10:27:48   作者:handsomeboysk  
本文主要介紹了PyTorch 分布式訓(xùn)練的實(shí)現(xiàn),包括數(shù)據(jù)并行、模型并行、混合并行和流水線并行等模式,感興趣的可以了解一下

在深度學(xué)習(xí)模型變得日益龐大之后,單個(gè) GPU 的顯存已經(jīng)無(wú)法滿足高效訓(xùn)練的需求。此時(shí),“分布式訓(xùn)練(Distributed Training)”技術(shù)應(yīng)運(yùn)而生,成為加速訓(xùn)練的重要手段。

本文將圍繞以下三行典型的 PyTorch 分布式訓(xùn)練代碼進(jìn)行詳細(xì)解析,并擴(kuò)展介紹分布式訓(xùn)練的核心概念和實(shí)踐方法:

local_rank = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
global_rank = int(os.getenv('RANK', -1))
world_size = int(os.getenv('WORLD_SIZE', 1))

一、什么是分布式訓(xùn)練?

分布式訓(xùn)練是指將模型訓(xùn)練過(guò)程劃分到多個(gè)計(jì)算設(shè)備(通常是多個(gè) GPU,甚至是多臺(tái)機(jī)器)上進(jìn)行協(xié)同處理,目標(biāo)是加速訓(xùn)練速度擴(kuò)展模型容量。

分布式訓(xùn)練可以分為以下幾種模式:

  • 數(shù)據(jù)并行(Data Parallelism):每個(gè) GPU 處理不同的數(shù)據(jù)子集,同步梯度。
  • 模型并行(Model Parallelism):將模型拆成多個(gè)部分,分別部署到不同的 GPU。
  • 混合并行(Hybrid Parallelism):結(jié)合模型并行和數(shù)據(jù)并行。
  • 流水線并行(Pipeline Parallelism):按層切分模型,不同 GPU 處理不同階段。

二、理解分布式訓(xùn)練的核心概念

1. World Size(全局進(jìn)程數(shù))

world_size = int(os.getenv('WORLD_SIZE', 1))
  • 含義:分布式訓(xùn)練中,所有參與訓(xùn)練的進(jìn)程總數(shù)。通常等于 GPU 總數(shù)。
  • 作用:用于初始化進(jìn)程組(torch.distributed.init_process_group()),讓每個(gè)進(jìn)程知道集群的規(guī)模。

比如你有兩臺(tái)機(jī)器,每臺(tái) 4 塊 GPU,那么 world_size = 8。

2. Rank(全局進(jìn)程編號(hào))

global_rank = int(os.getenv('RANK', -1))
  • 含義:標(biāo)識(shí)每個(gè)訓(xùn)練進(jìn)程在所有進(jìn)程中的唯一編號(hào)(從 0 開(kāi)始)。
  • 作用:常用于標(biāo)記主節(jié)點(diǎn)(rank == 0),控制日志輸出、模型保存等。

例如:

  • rank=0:負(fù)責(zé)打印日志、保存模型
  • rank=1,2,…:只做訓(xùn)練

3. Local Rank(本地進(jìn)程編號(hào))

local_rank = int(os.getenv('LOCAL_RANK', -1))
  • 含義:當(dāng)前訓(xùn)練進(jìn)程在本地機(jī)器上的 GPU 編號(hào)。一般與 CUDA_VISIBLE_DEVICES 配合使用。

  • 作用:用于指定該進(jìn)程應(yīng)該使用哪塊 GPU,如:

    torch.cuda.set_device(local_rank)
    

三、環(huán)境變量的設(shè)置方式

這些環(huán)境變量通常由 分布式啟動(dòng)器 設(shè)置。例如使用 torchrun

torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=192.168.1.1 --master_port=12345 train.py

torchrun 會(huì)自動(dòng)為每個(gè)進(jìn)程設(shè)置:

  • LOCAL_RANK
  • RANK
  • WORLD_SIZE

也可以手動(dòng)導(dǎo)出:

export WORLD_SIZE=8
export RANK=3
export LOCAL_RANK=3

四、分布式訓(xùn)練初始化流程(PyTorch 示例)

在 PyTorch 中,典型的初始化流程如下:

import os
import torch
import torch.distributed as dist

def setup_distributed():
    local_rank = int(os.getenv('LOCAL_RANK', -1))
    global_rank = int(os.getenv('RANK', -1))
    world_size = int(os.getenv('WORLD_SIZE', 1))

    torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend='nccl',  # GPU 用 nccl,CPU 用 gloo
        init_method='env://',
        world_size=world_size,
        rank=global_rank
    )
  • init_method='env://':表示從環(huán)境變量中讀取初始化信息。
  • nccl 是 NVIDIA 的高性能通信庫(kù),支持 GPU 間高速通信。

五、分布式訓(xùn)練的代碼結(jié)構(gòu)

使用 PyTorch 實(shí)現(xiàn)分布式訓(xùn)練的基本框架:

def train():
    setup_distributed()

    model = MyModel().cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=64)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        for batch in dataloader:
            # 正常訓(xùn)練流程

六、Elastic Training(彈性訓(xùn)練)

值得注意的是,示例代碼中注釋中提到的鏈接:https://pytorch.org/docs/stable/elastic/run.html

這是指 PyTorch 的 彈性分布式訓(xùn)練(Elastic Training),支持在訓(xùn)練過(guò)程中動(dòng)態(tài)增加或移除節(jié)點(diǎn),具備高容錯(cuò)性。

  • 工具:torch.distributed.elastic
  • 啟動(dòng)命令:torchrun --standalone --nnodes=1 --nproc_per_node=4 train.py

該特性對(duì)于大規(guī)模、長(zhǎng)時(shí)間訓(xùn)練任務(wù)非常重要。

七、總結(jié)

變量名含義來(lái)源典型用途
WORLD_SIZE全局進(jìn)程數(shù)量torchrun 設(shè)置初始化進(jìn)程組,全局通信
RANK當(dāng)前進(jìn)程的全局編號(hào)torchrun 設(shè)置控制主節(jié)點(diǎn)行為
LOCAL_RANK當(dāng)前進(jìn)程在本地的 GPU 編號(hào)torchrun 設(shè)置顯卡綁定:torch.cuda.set_device

這三行代碼雖然簡(jiǎn)單,卻是 PyTorch 分布式訓(xùn)練的入口。理解它們,就理解了 PyTorch 在分布式場(chǎng)景下的通信機(jī)制和訓(xùn)練框架。

如果你想要進(jìn)一步深入了解 PyTorch 分布式訓(xùn)練,推薦官方文檔:

到此這篇關(guān)于PyTorch 分布式訓(xùn)練的實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)PyTorch 分布式訓(xùn)練內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家! 

相關(guān)文章

  • Python中使用HTMLParser解析html實(shí)例

    Python中使用HTMLParser解析html實(shí)例

    這篇文章主要介紹了Python中使用HTMLParser解析html實(shí)例,本文直接給出使用示例,并總結(jié)出HTMLParser含有的方法分為兩類,一類是需要顯式調(diào)用的,而另一類不需顯示調(diào)用,需要的朋友可以參考下
    2015-02-02
  • Python使用LSTM實(shí)現(xiàn)銷售額預(yù)測(cè)詳解

    Python使用LSTM實(shí)現(xiàn)銷售額預(yù)測(cè)詳解

    大家經(jīng)常會(huì)遇到一些需要預(yù)測(cè)的場(chǎng)景,比如預(yù)測(cè)品牌銷售額,預(yù)測(cè)產(chǎn)品銷量。本文給大家分享一波使用?LSTM?進(jìn)行端到端時(shí)間序列預(yù)測(cè)的完整代碼和詳細(xì)解釋,需要的可以參考一下
    2022-07-07
  • Python Django查詢集的延遲加載特性詳解

    Python Django查詢集的延遲加載特性詳解

    在 Django 的開(kāi)發(fā)過(guò)程中,查詢集(QuerySet)是我們與數(shù)據(jù)庫(kù)進(jìn)行交互的重要工具,本文將深入探討 Django 查詢集的延遲加載特性,幫助新手理解其工作原理及優(yōu)缺點(diǎn),提供一些實(shí)用的代碼示例來(lái)展示延遲加載如何在實(shí)際項(xiàng)目中使用,需要的朋友可以參考下
    2024-10-10
  • Python設(shè)計(jì)模式之代理模式實(shí)例詳解

    Python設(shè)計(jì)模式之代理模式實(shí)例詳解

    這篇文章主要介紹了Python設(shè)計(jì)模式之代理模式,結(jié)合實(shí)例形式較為詳細(xì)的分析了代理模式的概念、原理及Python定義、使用代理模式相關(guān)操作技巧,需要的朋友可以參考下
    2019-01-01
  • Python中self關(guān)鍵字的用法解析

    Python中self關(guān)鍵字的用法解析

    在Python中,self是一個(gè)經(jīng)常出現(xiàn)的關(guān)鍵字,特別是在類定義中的方法,這篇文章主要和大家self的作用和用法,希望可以幫助大家更好地理解為什么需要它以及如何正確使用它
    2023-11-11
  • python迭代器的使用方法實(shí)例

    python迭代器的使用方法實(shí)例

    這篇文章主要介紹了python迭代器的使用方法,代碼很簡(jiǎn)單,大家可以參考使用
    2013-11-11
  • Python編程判斷一個(gè)正整數(shù)是否為素?cái)?shù)的方法

    Python編程判斷一個(gè)正整數(shù)是否為素?cái)?shù)的方法

    這篇文章主要介紹了Python編程判斷一個(gè)正整數(shù)是否為素?cái)?shù)的方法,涉及Python數(shù)學(xué)運(yùn)算相關(guān)操作技巧,需要的朋友可以參考下
    2017-04-04
  • python實(shí)現(xiàn)單鏈表的方法示例

    python實(shí)現(xiàn)單鏈表的方法示例

    這篇文章主要給大家介紹了關(guān)于python實(shí)現(xiàn)單鏈表的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-09-09
  • 對(duì)于Python異常處理慎用“except:pass”建議

    對(duì)于Python異常處理慎用“except:pass”建議

    這篇文章主要介紹了對(duì)于Python異常處理方法的建議,摘選自StackOverflow上的熱門問(wèn)題的回答,闡述了except:pass的使用時(shí)需要注意的地方,需要的朋友可以參考下
    2015-04-04
  • python matplotlib如何給圖中的點(diǎn)加標(biāo)簽

    python matplotlib如何給圖中的點(diǎn)加標(biāo)簽

    這篇文章主要介紹了python matplotlib給圖中的點(diǎn)加標(biāo)簽,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-11-11

最新評(píng)論