欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

基于BCEWithLogitsLoss樣本不均衡的處理方案

 更新時間:2021年05月13日 10:52:49   作者:ucas_fhx  
這篇文章主要介紹了BCEWithLogitsLoss樣本不均衡的處理方案,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

最近在做deepfake檢測任務(wù)(可以將其視為二分類問題,label為1和0),遇到了正負(fù)樣本不均衡的問題,正樣本數(shù)目是負(fù)樣本的5倍,這樣會導(dǎo)致FP率較高。

嘗試將正樣本的loss權(quán)重增高,看BCEWithLogitsLoss的源碼

Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)
 
Args:
    weight (Tensor, optional): a manual rescaling weight given to the loss
        of each batch element. If given, has to be a Tensor of size `nbatch`.
    size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
        the losses are averaged over each loss element in the batch. Note that for
        some losses, there are multiple elements per sample. If the field :attr:`size_average`
        is set to ``False``, the losses are instead summed for each minibatch. Ignored
        when reduce is ``False``. Default: ``True``
    reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
        losses are averaged or summed over observations for each minibatch depending
        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
        batch element instead and ignores :attr:`size_average`. Default: ``True``
    reduction (string, optional): Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        ``'mean'``: the sum of the output will be divided by the number of
        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
    pos_weight (Tensor, optional): a weight of positive examples.
            Must be a vector with length equal to the number of classes.

對其中的參數(shù)pos_weight的使用存在疑惑,BCEloss里的例子pos_weight = torch.ones([64]) # All weights are equal to 1,不懂為什么會有64個class,因?yàn)锽CEloss是針對二分類問題的loss,后經(jīng)過檢索,得知還有多標(biāo)簽分類,

多標(biāo)簽分類就是多個標(biāo)簽,每個標(biāo)簽有兩個label(0和1),這類任務(wù)同樣可以使用BCEloss。

現(xiàn)在講一下BCEWithLogitsLoss里的pos_weight使用方法

比如我們有正負(fù)兩類樣本,正樣本數(shù)量為100個,負(fù)樣本為400個,我們想要對正負(fù)樣本的loss進(jìn)行加權(quán)處理,將正樣本的loss權(quán)重放大4倍,通過這樣的方式緩解樣本不均衡問題。

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
 
# pos_weight (Tensor, optional): a weight of positive examples.
#            Must be a vector with length equal to the number of classes.

pos_weight里是一個tensor列表,需要和標(biāo)簽個數(shù)相同,比如我們現(xiàn)在是二分類,只需要將正樣本loss的權(quán)重寫上即可。

如果是多標(biāo)簽分類,有64個標(biāo)簽,則

Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)

補(bǔ)充:Pytorch —— BCEWithLogitsLoss()的一些問題

一、等價表達(dá)

1、pytorch:

torch.sigmoid() + torch.nn.BCELoss()

2、自己編寫

def ce_loss(y_pred, y_train, alpha=1):
    
    p = torch.sigmoid(y_pred)
    # p = torch.clamp(p, min=1e-9, max=0.99)  
    loss = torch.sum(- alpha * torch.log(p) * y_train \
           - torch.log(1 - p) * (1 - y_train))/len(y_train)
    return loss~

3、驗(yàn)證

import torch
import torch.nn as nn
torch.cuda.manual_seed(300)       # 為當(dāng)前GPU設(shè)置隨機(jī)種子
torch.manual_seed(300)            # 為CPU設(shè)置隨機(jī)種子
def ce_loss(y_pred, y_train, alpha=1):
   # 計算loss
   p = torch.sigmoid(y_pred)
   # p = torch.clamp(p, min=1e-9, max=0.99)
   loss = torch.sum(- alpha * torch.log(p) * y_train \
          - torch.log(1 - p) * (1 - y_train))/len(y_train)
   return loss
py_lossFun = nn.BCEWithLogitsLoss()
input = torch.randn((10000,1), requires_grad=True)
target = torch.ones((10000,1))
target.requires_grad_(True)
py_loss = py_lossFun(input, target)
py_loss.backward()
print("*********BCEWithLogitsLoss***********")
print("loss: ")
print(py_loss.item())
print("梯度: ")
print(input.grad)
input = input.detach()
input.requires_grad_(True)
self_loss = ce_loss(input, target)
self_loss.backward()
print("*********SelfCELoss***********")
print("loss: ")
print(self_loss.item())
print("梯度: ")
print(input.grad)

測試結(jié)果:

在這里插入圖片描述

– 由上結(jié)果可知,我編寫的loss和pytorch中提供的j基本一致。

– 但是僅僅這樣就可以了嗎?NO! 下面介紹BCEWithLogitsLoss()的強(qiáng)大之處:

– BCEWithLogitsLoss()具有很好的對nan的處理能力,對于我寫的代碼(四層神經(jīng)網(wǎng)絡(luò),層之間的激活函數(shù)采用的是ReLU,輸出層激活函數(shù)采用sigmoid(),由于數(shù)據(jù)處理的問題,所以會導(dǎo)致我們編寫的CE的loss出現(xiàn)nan:原因如下:

–首先神經(jīng)網(wǎng)絡(luò)輸出的pre_target較大,就會導(dǎo)致sigmoid之后的p為1,則torch.log(1 - p)為nan;

– 使用clamp(函數(shù)雖然會解除這個nan,但是由于在迭代過程中,網(wǎng)絡(luò)輸出可能越來越大(層之間使用的是ReLU),則導(dǎo)致我們寫的loss陷入到某一個數(shù)值而無法進(jìn)行優(yōu)化。但是BCEWithLogitsLoss()對這種情況下出現(xiàn)的nan有很好的處理,從而得到更好的結(jié)果。

– 我此實(shí)驗(yàn)的目的是為了比較CE和FL的區(qū)別,自己編寫FL,則必須也要自己編寫CE,不能使用BCEWithLogitsLoss()。

二、使用場景

二分類 + sigmoid()

使用sigmoid作為輸出層非線性表達(dá)的分類問題(雖然可以處理多分類問題,但是一般用于二分類,并且最后一層只放一個節(jié)點(diǎn))

三、注意事項(xiàng)

輸入格式

要求輸入的input和target均為float類型

以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python 探針的實(shí)現(xiàn)原理

    Python 探針的實(shí)現(xiàn)原理

    本文將簡單講述一下 Python 探針的實(shí)現(xiàn)原理。 同時為了驗(yàn)證這個原理,我們也會一起來實(shí)現(xiàn)一個簡單的統(tǒng)計指定函數(shù)執(zhí)行時間的探針程序。
    2016-04-04
  • python scipy卷積運(yùn)算的實(shí)現(xiàn)方法

    python scipy卷積運(yùn)算的實(shí)現(xiàn)方法

    這篇文章主要介紹了python scipy卷積運(yùn)算的實(shí)現(xiàn)方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-09-09
  • python爬蟲超時的處理的實(shí)例

    python爬蟲超時的處理的實(shí)例

    今天小編就為大家分享一篇python爬蟲超時的處理的實(shí)例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12
  • 把大數(shù)據(jù)數(shù)字口語化(python與js)兩種實(shí)現(xiàn)

    把大數(shù)據(jù)數(shù)字口語化(python與js)兩種實(shí)現(xiàn)

    當(dāng)出現(xiàn)萬以上的整型數(shù)字時,經(jīng)常要把它們口語化比較直觀。下面分享兩段代碼,python與js的
    2013-02-02
  • Pycharm創(chuàng)建python文件自動添加日期作者等信息(步驟詳解)

    Pycharm創(chuàng)建python文件自動添加日期作者等信息(步驟詳解)

    這篇文章主要介紹了Pycharm創(chuàng)建python文件自動添加日期作者等信息(步驟詳解),本文分步驟給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2021-02-02
  • python爬蟲lxml庫解析xpath網(wǎng)頁過程示例

    python爬蟲lxml庫解析xpath網(wǎng)頁過程示例

    這篇文章主要為大家介紹了python爬蟲lxml庫解析xpath網(wǎng)頁的過程示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-05-05
  • Python實(shí)現(xiàn)修改文件內(nèi)容的方法分析

    Python實(shí)現(xiàn)修改文件內(nèi)容的方法分析

    這篇文章主要介紹了Python實(shí)現(xiàn)修改文件內(nèi)容的方法,結(jié)合實(shí)例形式分析了Python文件讀寫、字符串替換及shell方法調(diào)用等相關(guān)操作技巧,需要的朋友可以參考下
    2018-03-03
  • python中join()方法介紹

    python中join()方法介紹

    Python join() 方法用于將序列中的元素以指定的字符連接生成一個新的字符串。這篇文章主要介紹了python中join()方法,需要的朋友可以參考下
    2018-10-10
  • python添加菜單圖文講解

    python添加菜單圖文講解

    在本篇文章中小編給大家整理的是關(guān)于python添加菜單圖文講解以及步驟分析,需要的朋友們學(xué)習(xí)下吧。
    2019-06-06
  • flask框架視圖函數(shù)用法示例

    flask框架視圖函數(shù)用法示例

    這篇文章主要介紹了flask框架視圖函數(shù)用法,結(jié)合實(shí)例形式分析了flask框架視圖函數(shù)常見配置與使用技巧,需要的朋友可以參考下
    2018-07-07

最新評論