使用Pytorch訓(xùn)練分類問題時(shí),分類準(zhǔn)確率的計(jì)算方式
Pytorch訓(xùn)練分類問題時(shí),分類準(zhǔn)確率的計(jì)算
作者記錄方便查詢
使用條件
真實(shí)標(biāo)簽與預(yù)測(cè)標(biāo)簽都是tensor。
使用方法
#標(biāo)簽情況 print(y) tensor([[1, 1, 0, 0]]) print(pred) tensor([[1, 0, 1, 0]]) # 比較真實(shí)與預(yù)測(cè) print(y==pred) tensor([[ True, False, False, True]]) # 對(duì)正確元素求和,sum會(huì)自動(dòng)計(jì)算True的個(gè)數(shù) print((y==pred).sum()) tensor(2)
因此在每個(gè)epoch開始時(shí),只需要初始化一個(gè)計(jì)數(shù)器accuracy,對(duì)每次的正確元素進(jìn)行累加,在除以訓(xùn)練元素的總數(shù),便獲得了每個(gè)epoch的準(zhǔn)確率。
for epoch in range(epochs): accuracy=0 for i, (x,y) in enumerate(train_loader, 1): pred = net(x) loss = loss_function(pred.to(torch.float32),y.to(torch.float32)) optimizer.zero_grad() loss.backward() #反向傳播 optimizer.step() #更新梯度 loss_steps[epoch]=loss.item()#保存loss running_loss = loss.item() accuracy += (pred == y).sum() acc = float(accuracy*100)/float(len(train_ids))# 除以元素總數(shù),可以用其他方式獲取 print(f"第{epoch}次訓(xùn)練,loss={running_loss:.4f},Accuracy={acc:.3f}".format(epoch,running_loss,acc))
結(jié)果
Pytorch 計(jì)算分類器準(zhǔn)確率(總分類及子分類)
分類器平均準(zhǔn)確率計(jì)算
correct = torch.zeros(1).squeeze().cuda() total = torch.zeros(1).squeeze().cuda() for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) correct += (prediction == labels).sum().float() total += len(labels) acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())
分類器各個(gè)子類準(zhǔn)確率計(jì)算
correct = list(0. for i in range(args.class_num)) total = list(0. for i in range(args.class_num)) for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) res = prediction == labels for label_idx in range(len(labels)): label_single = label[label_idx] correct[label_single] += res[label_idx].item() total[label_single] += 1 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total)) for acc_idx in range(len(train_class_correct)): try: acc = correct[acc_idx]/total[acc_idx] except: acc = 0 finally: acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python?pyqt5下拉多選框的實(shí)現(xiàn)示例
QComboBox是一個(gè)集按鈕和下拉選項(xiàng)于一體的控件,本文主要介紹了Python?pyqt5下拉多選框的實(shí)現(xiàn)示例,具有一定的參考價(jià)值,感興趣的可以了解一下2025-04-04解決Python下imread,imwrite不支持中文的問題
今天小編就為大家分享一篇解決Python下imread,imwrite不支持中文的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-12-12淺談Pandas dataframe數(shù)據(jù)處理方法的速度比較
這篇文章主要介紹了淺談Pandas dataframe數(shù)據(jù)處理方法的速度比較,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-04-04使用Python實(shí)現(xiàn)簡單的人臉識(shí)別功能(附源碼)
Python中實(shí)現(xiàn)人臉識(shí)別功能有多種方法,依賴于python膠水語言的特性,我們通過調(diào)用包可以快速準(zhǔn)確的達(dá)成這一目的,本文給大家分享使用Python實(shí)現(xiàn)簡單的人臉識(shí)別功能的操作步驟,感興趣的朋友一起看看吧2021-12-12