NCCL深度學(xué)習(xí)Bootstrap網(wǎng)絡(luò)連接建立源碼解析
引言
上次介紹到rank0的機(jī)器生成了ncclUniqueId,并完成了機(jī)器的bootstrap網(wǎng)絡(luò)和通信網(wǎng)絡(luò)的初始化,這節(jié)接著看下所有節(jié)點(diǎn)間bootstrap的連接是如何建立的。
rank0節(jié)點(diǎn)執(zhí)行ncclGetUniqueId生成ncclUniqueId
通過mpi將Id廣播到所有節(jié)點(diǎn),然后所有節(jié)點(diǎn)都會(huì)執(zhí)行ncclCommInitRank,這里其他節(jié)點(diǎn)也會(huì)進(jìn)行初始化bootstrap網(wǎng)絡(luò)和通信網(wǎng)絡(luò)的操作,然后會(huì)執(zhí)行到ncclCommInitRankSync。
ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank, int cudaDev) {
ncclResult_t res;
CUDACHECK(cudaSetDevice(cudaDev));
NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup);
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %x - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId);
return ncclSuccess;
cleanup:
if ((*newcomm) && (*newcomm)->bootstrap) bootstrapAbort((*newcomm)->bootstrap);
*newcomm = NULL;
return res;
}
ncclComm_t是指向ncclComm的指針,ncclComm是一個(gè)大雜燴,包含了通信用到的所有上下文信息,里面的字段等用到的時(shí)候再介紹,然后通過commAlloc分配newcom,并且完成初始化,比如當(dāng)前是哪個(gè)卡,對(duì)應(yīng)的pcie busid是什么,
執(zhí)行initTransportsRank
static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* commId) {
// We use 3 AllGathers
// 1. { peerInfo, comm }
// 2. ConnectTransport[nranks], ConnectValue[nranks]
// 3. { nThreads, nrings, compCap, prev[MAXCHANNELS], next[MAXCHANNELS] }
int rank = comm->rank;
int nranks = comm->nRanks;
uint64_t commHash = getHash(commId->internal, NCCL_UNIQUE_ID_BYTES);
TRACE(NCCL_INIT, "comm %p, commHash %lx, rank %d nranks %d - BEGIN", comm, commHash, rank, nranks);
NCCLCHECK(bootstrapInit(commId, rank, nranks, &comm->bootstrap));
// AllGather1 - begin
struct {
struct ncclPeerInfo peerInfo;
struct ncclComm* comm;
} *allGather1Data;
NCCLCHECK(ncclCalloc(&allGather1Data, nranks));
allGather1Data[rank].comm = comm;
struct ncclPeerInfo* myInfo = &allGather1Data[rank].peerInfo;
NCCLCHECK(fillInfo(comm, myInfo, commHash));
NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather1Data, sizeof(*allGather1Data)));
NCCLCHECK(ncclCalloc(&comm->peerInfo, nranks+1)); // Extra rank to represent CollNet root
for (int i = 0; i < nranks; i++) {
memcpy(comm->peerInfo+i, &allGather1Data[i].peerInfo, sizeof(struct ncclPeerInfo));
if ((i != rank) && (comm->peerInfo[i].hostHash == myInfo->hostHash) && (comm->peerInfo[i].busId == myInfo->busId)) {
WARN("Duplicate GPU detected : rank %d and rank %d both on CUDA device %x", rank, i, myInfo->busId);
return ncclInvalidUsage;
}
}
看下bootstrapInit
ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) {
ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
bool idFromEnv = getenv("NCCL_COMM_ID") != NULL;
struct extState* state;
NCCLCHECK(ncclCalloc(&state, 1));
state->rank = rank;
state->nranks = nranks;
*commState = state;
TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks);
struct extInfo info = { 0 };
info.rank = rank;
info.nranks = nranks;
void *tmpSendComm, *tmpRecvComm;
// Pass the remote address to listen via info
if (idFromEnv) {
memcpy(&info.extHandleListen, netHandle, sizeof(ncclNetHandle_t));
memcpy(&info.extHandleListenRoot, netHandle, sizeof(ncclNetHandle_t));
}
// listen will return the local address via info (specify interface type 'findSubnetIf')
state->dev = idFromEnv ? findSubnetIf : 0;
void* extBstrapListenCommRoot;
NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListen, &state->extBstrapListenComm));
NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListenRoot, &extBstrapListenCommRoot));
// stagger connection times to avoid an overload of the root at very high rank counts
if (nranks > 128) {
long msec = rank;
struct timespec tv;
tv.tv_sec = msec / 1000;
tv.tv_nsec = 1000000 * (msec % 1000);
TRACE(NCCL_INIT, "rank %d delaying connection to root by %ld msec", rank, msec);
(void) nanosleep(&tv, NULL);
}
// send info on my listening socket to root
NCCLCHECK(bootstrapNetConnect(state->dev, netHandle, &tmpSendComm));
NCCLCHECK(bootstrapNetSend(tmpSendComm, &info, sizeof(info)));
NCCLCHECK(bootstrapNetCloseSend(tmpSendComm));
// get info on my "next" rank in the bootstrap ring from root
}
首先看下commState
即ncclComm的bootstrap,類型為extState。
struct extState {
void* extBstrapListenComm;
void* extBstrapRingRecvComm;
void* extBstrapRingSendComm;
ncclNetHandle_t* peerBstrapHandles;
struct unexConn* unexpectedConnections;
int rank;
int nranks;
int dev;
};
其中extBstrapRingSendComm是當(dāng)前節(jié)點(diǎn)連接next的socket連接,extBstrapRingRecvComm是當(dāng)前節(jié)點(diǎn)和prev節(jié)點(diǎn)的socket連接,extBstrapListenComm是當(dāng)前節(jié)點(diǎn)的監(jiān)聽socket,peerBstrapHandles是所有rank的ip port(對(duì)應(yīng)extBstrapListenComm),dev默認(rèn)為0,表示用第幾個(gè)ip地址。
然后通過bootstrapNetListen創(chuàng)建extHandleListen和extHandleListenRoot兩個(gè)bootstrap comm,如前文所述,bootstrap comm其實(shí)就是保存了fd,這里創(chuàng)建兩個(gè)comm的原因是extHandleListen是rank之間實(shí)際使用的bootstrap連接,extHandleListenRoot是rank0節(jié)點(diǎn)和其他所有rank進(jìn)行通信使用的連接。
static ncclResult_t bootstrapNetListen(int dev, ncclNetHandle_t* netHandle, void** listenComm)
bootstrapNetListen函數(shù)上節(jié)有介紹過,會(huì)獲取到第dev個(gè)當(dāng)前機(jī)器的ip,然后listen獲取監(jiān)聽fd,將ip port寫到nethandle,獲取到的bootstrap comm寫到listencomm。
然后將rank,nrank,extHandleListen和extHandleListenRoot寫到extInfo里。
struct extInfo {
int rank;
int nranks;
ncclNetHandle_t extHandleListenRoot;
ncclNetHandle_t extHandleListen;
};
netHandle為ncclUniqueId,即rank0的ip port,然后通過bootstrapNetConnect創(chuàng)建bootstrap send comm,類比bootstrapNetListen,bootstrapNetConnect就是建立到netHandle的socket連接,將socket寫到sendComm里,這里dev并沒有用到。
static ncclResult_t bootstrapNetConnect(int dev, ncclNetHandle_t* netHandle, void** sendComm)
然后通過bootstrapNetSend將extInfo發(fā)送出去,即發(fā)給rank0:
static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
NCCLCHECK(socketSend(comm->fd, data, size));
return ncclSuccess;
}
其中socketSend就是執(zhí)行send接口發(fā)送數(shù)據(jù)。
然后通過bootstrapNetCloseSend關(guān)閉fd
rank0收到數(shù)據(jù)后會(huì)做什么工作呢,回顧一下,rank0的節(jié)執(zhí)行ncclGetUniqueId生成ncclUniqueId,其中在執(zhí)行bootstrapCreateRoot的最后會(huì)啟動(dòng)一個(gè)線程執(zhí)行bootstrapRoot。
static void *bootstrapRoot(void* listenComm) {
struct extInfo info;
ncclNetHandle_t *rankHandles = NULL;
ncclNetHandle_t *rankHandlesRoot = NULL; // for initial rank <-> root information exchange
ncclNetHandle_t zero = { 0 }; // for sanity checking
void* tmpComm;
ncclResult_t res;
setFilesLimit();
TRACE(NCCL_INIT, "BEGIN");
/* Receive addresses from all ranks */
int nranks = 0, c = 0;
do {
NCCLCHECKGOTO(bootstrapNetAccept(listenComm, &tmpComm), res, out);
NCCLCHECKGOTO(bootstrapNetRecv(tmpComm, &info, sizeof(info)), res, out);
NCCLCHECKGOTO(bootstrapNetCloseRecv(tmpComm), res, out);
if (c == 0) {
nranks = info.nranks;
NCCLCHECKGOTO(ncclCalloc(&rankHandles, nranks), res, out);
NCCLCHECKGOTO(ncclCalloc(&rankHandlesRoot, nranks), res, out);
}
if (nranks != info.nranks) {
WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nranks);
goto out;
}
if (memcmp(&zero, &rankHandlesRoot[info.rank], sizeof(ncclNetHandle_t)) != 0) {
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
goto out;
}
// Save the connection handle for that rank
memcpy(rankHandlesRoot+info.rank, info.extHandleListenRoot, sizeof(ncclNetHandle_t));
memcpy(rankHandles+info.rank, info.extHandleListen, sizeof(ncclNetHandle_t));
++c;
TRACE(NCCL_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks);
} while (c < nranks);
TRACE(NCCL_INIT, "COLLECTED ALL %d HANDLES", nranks);
// Send the connect handle for the next rank in the AllGather ring
for (int r=0; r<nranks; ++r) {
int next = (r+1) % nranks;
void *tmpSendComm;
NCCLCHECKGOTO(bootstrapNetConnect(0, rankHandlesRoot+r, &tmpSendComm), res, out);
NCCLCHECKGOTO(bootstrapNetSend(tmpSendComm, rankHandles+next, sizeof(ncclNetHandle_t)), res, out);
NCCLCHECKGOTO(bootstrapNetCloseSend(tmpSendComm), res, out);
}
TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks);
out:
bootstrapNetCloseListen(listenComm);
if (rankHandles) free(rankHandles);
if (rankHandlesRoot) free(rankHandlesRoot);
TRACE(NCCL_INIT, "DONE");
return NULL;
}
listenComm是上一個(gè)博文中rank0創(chuàng)建的監(jiān)聽fd,bootstrapNetAccept是從listenComm中獲取一個(gè)新連接,使用新連接的fd創(chuàng)建recvcomm。
static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm)
然后通過bootstrapNetRecv讀取tmpComm的數(shù)據(jù),即其他rank發(fā)送來的extInfo,然后保存其他rank的extHandleListen和extHandleListenRoot,這個(gè)時(shí)候rank0就獲取到其他所有rank的ip和port了。
獲取完所有rank的info之后開始建環(huán),將節(jié)點(diǎn)(r+1) % nranks的extHandleListen發(fā)送給節(jié)點(diǎn)r,就是說將節(jié)點(diǎn)r的next節(jié)點(diǎn)的nethandle發(fā)送給節(jié)點(diǎn)r。這里可以看出,每個(gè)節(jié)點(diǎn)創(chuàng)建了兩個(gè)listen comm,其中rank0使用extHandleListenRoot進(jìn)行通信,其他節(jié)點(diǎn)之間通過extHandleListen進(jìn)行通信。
然后再回去接著看bootstrapInit
ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) {
// get info on my "next" rank in the bootstrap ring from root
ncclNetHandle_t extHandleNext;
NCCLCHECK(bootstrapNetAccept(extBstrapListenCommRoot, &tmpRecvComm));
NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &extHandleNext, sizeof(extHandleNext)));
NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
NCCLCHECK(bootstrapNetCloseListen(extBstrapListenCommRoot));
NCCLCHECK(bootstrapNetConnect(state->dev, &extHandleNext, &state->extBstrapRingSendComm));
// Accept the connect request from the previous rank in the AllGather ring
NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &state->extBstrapRingRecvComm));
// AllGather all listen handlers
NCCLCHECK(ncclCalloc(&state->peerBstrapHandles, nranks));
memcpy(state->peerBstrapHandles+rank, info.extHandleListen, sizeof(ncclNetHandle_t));
NCCLCHECK(bootstrapAllGather(state, state->peerBstrapHandles, sizeof(ncclNetHandle_t)));
TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
return ncclSuccess;
}
接著所有rank都會(huì)在extHandleListenRoot上接收新連接創(chuàng)建tmpRecvComm,然后接收到當(dāng)前rank的next的ip,port;然后連接next創(chuàng)建bscomm到state->extBstrapRingSendComm,接收prev的連接創(chuàng)建bscomm到state->extBstrapRingRecvComm,到現(xiàn)在bootstrap網(wǎng)絡(luò)連接就完全建立起來了,如下圖:

最后gather所有rank的ip port
首先將自己的nethandle放到peerBstrapHandles的對(duì)應(yīng)位置,如下所示。

然后執(zhí)行bootstrapAllGather:
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
struct extState* state = (struct extState*)commState;
char* data = (char*)allData;
int rank = state->rank;
int nranks = state->nranks;
TRACE(NCCL_INIT, "rank %d nranks %d size %d", rank, nranks, size);
/* Simple ring based AllGather
* At each step i receive data from (rank-i-1) from left
* and send previous step's data from (rank-i) to right
*/
for (int i=0; i<nranks-1; i++) {
size_t rslice = (rank - i - 1 + nranks) % nranks;
size_t sslice = (rank - i + nranks) % nranks;
// Send slice to the right
NCCLCHECK(bootstrapNetSend(state->extBstrapRingSendComm, data+sslice*size, size));
// Recv slice from the left
NCCLCHECK(bootstrapNetRecv(state->extBstrapRingRecvComm, data+rslice*size, size));
}
TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
return ncclSuccess;
}
每一次將自己的data發(fā)送給對(duì)應(yīng)的rank,然后接收其他rank發(fā)送過來的data,如下圖。
第一步:

第二步:

到這里每個(gè)rank就都有了全局所有rank的ip port。
最后總結(jié)一下,本節(jié)主要?jiǎng)?chuàng)建了bootstrap環(huán)形網(wǎng)絡(luò)連接,并保存到ncclComm里。
歡迎 Star、試用 OneFlow 最新版本:github.com/Oneflow-Inc…
以上就是NCCL深度學(xué)習(xí)Bootstrap網(wǎng)絡(luò)連接建立源碼解析的詳細(xì)內(nèi)容,更多關(guān)于NCCL Bootstrap網(wǎng)絡(luò)連接的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Django項(xiàng)目在pycharm新建的步驟方法
在本篇文章里小編給大家整理的是一篇關(guān)于Django項(xiàng)目在pycharm新建的步驟方法,有興趣的朋友們可以學(xué)習(xí)參考下。2021-03-03
Python List cmp()知識(shí)點(diǎn)總結(jié)
在本篇內(nèi)容里小編給大家整理了關(guān)于Python List cmp()用法相關(guān)知識(shí)點(diǎn),有需要的朋友們跟著學(xué)習(xí)下。2019-02-02
Python利用jmespath模塊進(jìn)行json數(shù)據(jù)處理
jmespath是python的第三方模塊,是需要額外安裝的。它在python原有的json數(shù)據(jù)處理上做出了很大的貢獻(xiàn)。本文將詳細(xì)介紹如何利用jmespath實(shí)現(xiàn)json數(shù)據(jù)處理,需要的可以參考一下2022-03-03
Python?torch.onnx.export用法詳細(xì)介紹
這篇文章主要給大家介紹了關(guān)于Python?torch.onnx.export用法詳細(xì)介紹的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2022-07-07
在 Windows 下搭建高效的 django 開發(fā)環(huán)境的詳細(xì)教程
這篇文章主要介紹了如何在 Windows 下搭建高效的 django 開發(fā)環(huán)境,本文通過一篇詳細(xì)教程實(shí)例代碼相結(jié)合給大家講解的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-07-07
Jupyter notebook如何實(shí)現(xiàn)打開數(shù)據(jù)集
這篇文章主要介紹了Jupyter notebook如何實(shí)現(xiàn)打開數(shù)據(jù)集問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-01-01
Python腳本實(shí)現(xiàn)一鍵自動(dòng)整理辦公文件
這篇文章主要介紹了Python實(shí)現(xiàn)腳本一鍵自動(dòng)整理辦公文件,文件下載文件夾就變得亂七八糟,整理的時(shí)候非常痛苦,巴不得有一個(gè)自動(dòng)化的工具幫我歸類文檔。下面小編就給大家分享自動(dòng)化整理文件的小技巧,需要的朋友可以參考一下文章內(nèi)容2022-02-02
Numpy實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)(CNN)的示例
這篇文章主要介紹了Numpy實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)(CNN)的示例,幫助大家更好的理解和使用Numpy,感興趣的朋友可以了解下2020-10-10
終于明白tf.reduce_sum()函數(shù)和tf.reduce_mean()函數(shù)用法
這篇文章主要介紹了終于明白tf.reduce_sum()函數(shù)和tf.reduce_mean()函數(shù)用法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-11-11

