使用Pytorch訓練分類問題時,分類準確率的計算方式
更新時間:2023年09月14日 14:24:58 作者:jayus丶
這篇文章主要介紹了使用Pytorch訓練分類問題時,分類準確率的計算方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
Pytorch訓練分類問題時,分類準確率的計算
作者記錄方便查詢
使用條件
真實標簽與預測標簽都是tensor。
使用方法
#標簽情況 print(y) tensor([[1, 1, 0, 0]]) print(pred) tensor([[1, 0, 1, 0]]) # 比較真實與預測 print(y==pred) tensor([[ True, False, False, True]]) # 對正確元素求和,sum會自動計算True的個數(shù) print((y==pred).sum()) tensor(2)
因此在每個epoch開始時,只需要初始化一個計數(shù)器accuracy,對每次的正確元素進行累加,在除以訓練元素的總數(shù),便獲得了每個epoch的準確率。
for epoch in range(epochs):
accuracy=0
for i, (x,y) in enumerate(train_loader, 1):
pred = net(x)
loss = loss_function(pred.to(torch.float32),y.to(torch.float32))
optimizer.zero_grad()
loss.backward() #反向傳播
optimizer.step() #更新梯度
loss_steps[epoch]=loss.item()#保存loss
running_loss = loss.item()
accuracy += (pred == y).sum()
acc = float(accuracy*100)/float(len(train_ids))# 除以元素總數(shù),可以用其他方式獲取
print(f"第{epoch}次訓練,loss={running_loss:.4f},Accuracy={acc:.3f}".format(epoch,running_loss,acc))結果

Pytorch 計算分類器準確率(總分類及子分類)
分類器平均準確率計算
correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
images = Variable(images.cuda())
labels = Variable(labels.cuda())
output = model(images)
prediction = torch.argmax(output, 1)
correct += (prediction == labels).sum().float()
total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())分類器各個子類準確率計算
correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
images = Variable(images.cuda())
labels = Variable(labels.cuda())
output = model(images)
prediction = torch.argmax(output, 1)
res = prediction == labels
for label_idx in range(len(labels)):
label_single = label[label_idx]
correct[label_single] += res[label_idx].item()
total[label_single] += 1
acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
for acc_idx in range(len(train_class_correct)):
try:
acc = correct[acc_idx]/total[acc_idx]
except:
acc = 0
finally:
acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)總結
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
解決Python下imread,imwrite不支持中文的問題
今天小編就為大家分享一篇解決Python下imread,imwrite不支持中文的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12
淺談Pandas dataframe數(shù)據(jù)處理方法的速度比較
這篇文章主要介紹了淺談Pandas dataframe數(shù)據(jù)處理方法的速度比較,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-04-04
使用Python實現(xiàn)簡單的人臉識別功能(附源碼)
Python中實現(xiàn)人臉識別功能有多種方法,依賴于python膠水語言的特性,我們通過調用包可以快速準確的達成這一目的,本文給大家分享使用Python實現(xiàn)簡單的人臉識別功能的操作步驟,感興趣的朋友一起看看吧2021-12-12

