欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch實(shí)現(xiàn)對(duì)輸入超過三通道的數(shù)據(jù)進(jìn)行訓(xùn)練

 更新時(shí)間:2020年01月15日 10:06:51   作者:東城青年  
今天小編就為大家分享一篇pytorch實(shí)現(xiàn)對(duì)輸入超過三通道的數(shù)據(jù)進(jìn)行訓(xùn)練,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧

案例背景:視頻識(shí)別

假設(shè)每次輸入是8s的灰度視頻,視頻幀率為25fps,則視頻由200幀圖像序列構(gòu)成.每幀是一副單通道的灰度圖像,通過pythonb里面的np.stack(深度拼接)可將200幀拼接成200通道的深度數(shù)據(jù).進(jìn)而送到網(wǎng)絡(luò)里面去訓(xùn)練.

如果輸入圖像200通道覺得多,可以對(duì)視頻進(jìn)行抽幀,針對(duì)具體場(chǎng)景可以隨機(jī)抽幀或等間隔抽幀.比如這里等間隔抽取40幀.則最后輸入視頻相當(dāng)于輸入一個(gè)40通道的圖像數(shù)據(jù)了.

pytorch對(duì)超過三通道數(shù)據(jù)的加載:

讀取視頻每一幀,轉(zhuǎn)為array格式,然后依次將每一幀進(jìn)行深度拼接,最后得到一個(gè)40通道的array格式的深度數(shù)據(jù),保存到pickle里.

對(duì)每個(gè)視頻都進(jìn)行上述操作,保存到pickle里.

我這里將火的視頻深度數(shù)據(jù)保存在一個(gè).pkl文件中,一共2504個(gè)火的視頻,即2504個(gè)火的深度數(shù)據(jù).

將非火的視頻深度數(shù)據(jù)保存在一個(gè).pkl文件中,一共3985個(gè)非火的視頻,即3985個(gè)非火的深度數(shù)據(jù).

數(shù)據(jù)加載

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*寬*通道
      fire = fire.transpose(2,0,1)#通道*高*寬
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#轉(zhuǎn)換成pytorch網(wǎng)絡(luò)輸入的格式(批量大小,通道數(shù),高,寬)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型訓(xùn)練

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定義
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #損失函數(shù)
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #學(xué)習(xí)率
  lr = opt.lr
  
  #優(yōu)化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #訓(xùn)練
  for epoch in range(opt.max_epoch):
     #訓(xùn)練數(shù)據(jù)
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #輸入
      input = V(datas.cuda()).float()
      #目標(biāo)
      target = V(labels.cuda()).long()
      #輸出
      score = model(input).cuda()
      #損失
      loss = criterion(score,target)
      loss_sum += loss
      #反向傳播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解決方案:target = target.long()

以上這篇pytorch實(shí)現(xiàn)對(duì)輸入超過三通道的數(shù)據(jù)進(jìn)行訓(xùn)練就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 局域網(wǎng)內(nèi)python socket實(shí)現(xiàn)windows與linux間的消息傳送

    局域網(wǎng)內(nèi)python socket實(shí)現(xiàn)windows與linux間的消息傳送

    這篇文章主要介紹了局域網(wǎng)內(nèi)python socket實(shí)現(xiàn)windows與linux間的消息傳送的相關(guān)知識(shí),非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2019-04-04
  • python實(shí)現(xiàn)人工蜂群算法

    python實(shí)現(xiàn)人工蜂群算法

    這篇文章主要介紹了python如何實(shí)現(xiàn)人工蜂群算法,幫助大家更好的利用python進(jìn)行數(shù)據(jù)分析,感興趣的朋友可以了解下
    2020-09-09
  • Python線程編程之Thread詳解

    Python線程編程之Thread詳解

    這篇文章主要為大家介紹了Python線程編程之Thread,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助
    2021-12-12
  • 關(guān)于keras.layers.Conv1D的kernel_size參數(shù)使用介紹

    關(guān)于keras.layers.Conv1D的kernel_size參數(shù)使用介紹

    這篇文章主要介紹了關(guān)于keras.layers.Conv1D的kernel_size參數(shù)使用介紹,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • python自定義解析簡單xml格式文件的方法

    python自定義解析簡單xml格式文件的方法

    這篇文章主要介紹了python自定義解析簡單xml格式文件的方法,涉及Python解析XML文件的相關(guān)技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下
    2015-05-05
  • Python腳本處理空格的方法

    Python腳本處理空格的方法

    這篇文章主要介紹了Python腳本處理空格的方法,解決方案非常簡單,但是好多朋友都不知道,下面小編把解決方案分享到腳本之家平臺(tái),供大家參考
    2016-08-08
  • 徹底解決Python包下載慢問題

    徹底解決Python包下載慢問題

    這篇文章主要介紹了徹底解決Python包下載慢問題,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-11-11
  • Django實(shí)現(xiàn)任意文件上傳(最簡單的方法)

    Django實(shí)現(xiàn)任意文件上傳(最簡單的方法)

    這篇文章主要介紹了Django實(shí)現(xiàn)任意文件上傳,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-06-06
  • 如何從Python的cmd中獲得.py文件參數(shù)

    如何從Python的cmd中獲得.py文件參數(shù)

    這篇文章主要介紹了如何從Python的cmd中獲得.py文件參數(shù)操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2021-05-05
  • python __init__與 __new__的區(qū)別

    python __init__與 __new__的區(qū)別

    本文主要介紹了python __init__與 __new__的區(qū)別,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02

最新評(píng)論