欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

使用tensorflow實(shí)現(xiàn)VGG網(wǎng)絡(luò),訓(xùn)練mnist數(shù)據(jù)集方式

 更新時(shí)間:2020年05月26日 11:37:40   作者:masterjames  
這篇文章主要介紹了使用tensorflow實(shí)現(xiàn)VGG網(wǎng)絡(luò),訓(xùn)練mnist數(shù)據(jù)集方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧

VGG作為流行的幾個(gè)模型之一,訓(xùn)練圖形數(shù)據(jù)效果不錯(cuò),在mnist數(shù)據(jù)集是常用的入門集數(shù)據(jù),VGG層數(shù)非常多,如果嚴(yán)格按照規(guī)范來實(shí)現(xiàn),并用來訓(xùn)練mnist數(shù)據(jù)集,會(huì)出現(xiàn)各種問題,如,經(jīng)過16層卷積后,28*28*1的圖片幾乎無法進(jìn)行。

先介紹下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman實(shí)現(xiàn)的卷積神經(jīng)網(wǎng)絡(luò),現(xiàn)在稱其為VGGNet。它主要的貢獻(xiàn)是展示出網(wǎng)絡(luò)的深度是算法優(yōu)良性能的關(guān)鍵部分。

他們最好的網(wǎng)絡(luò)包含了16個(gè)卷積/全連接層。網(wǎng)絡(luò)的結(jié)構(gòu)非常一致,從頭到尾全部使用的是3x3的卷積和2x2的匯聚。他們的預(yù)訓(xùn)練模型是可以在網(wǎng)絡(luò)上獲得并在Caffe中使用的。

VGGNet不好的一點(diǎn)是它耗費(fèi)更多計(jì)算資源,并且使用了更多的參數(shù),導(dǎo)致更多的內(nèi)存占用(140M)。其中絕大多數(shù)的參數(shù)都是來自于第一個(gè)全連接層。

模型結(jié)構(gòu):

本文在實(shí)現(xiàn)時(shí)候,盡量保存VGG原來模型結(jié)構(gòu),核心代碼如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷積實(shí)現(xiàn):

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根據(jù)模型定義還需要更多的POOL化,但mnist圖片大小不允許。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

訓(xùn)練:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 計(jì)算損失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!") 

最終效果:

訓(xùn)練10000次后:結(jié)果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

這種深度的模型可以考慮循環(huán)10萬次以上。目前效果還不錯(cuò),本人沒有GPU,心痛筆記本的CPU,100%的CPU利用率,聽到風(fēng)扇響就不忍心再訓(xùn)練,本文也借鑒了alex網(wǎng)絡(luò)實(shí)現(xiàn),當(dāng)然我也實(shí)現(xiàn)了這個(gè)網(wǎng)絡(luò)模型。在MNIST數(shù)據(jù)上,ALEX由于層數(shù)較少,收斂更快,當(dāng)然MNIST,用CNN足夠了。

以上這篇使用tensorflow實(shí)現(xiàn)VGG網(wǎng)絡(luò),訓(xùn)練mnist數(shù)據(jù)集方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python代碼如何轉(zhuǎn)jar包

    python代碼如何轉(zhuǎn)jar包

    這篇文章主要介紹了python代碼如何轉(zhuǎn)jar包問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2024-03-03
  • python2與python3的print及字符串格式化小結(jié)

    python2與python3的print及字符串格式化小結(jié)

    最近一直在用python寫程序,對(duì)于python的print一直很惱火,老是不按照預(yù)期輸出。今天特來總結(jié)一樣print和format,也希望能幫助大家徹底理解它們
    2018-11-11
  • Python實(shí)現(xiàn)抓取HTML網(wǎng)頁(yè)并以PDF文件形式保存的方法

    Python實(shí)現(xiàn)抓取HTML網(wǎng)頁(yè)并以PDF文件形式保存的方法

    這篇文章主要介紹了Python實(shí)現(xiàn)抓取HTML網(wǎng)頁(yè)并以PDF文件形式保存的方法,結(jié)合實(shí)例形式分析了PyPDF2模塊的安裝及Python抓取HTML頁(yè)面并基于PyPDF2模塊生成pdf文件的相關(guān)操作技巧,需要的朋友可以參考下
    2018-05-05
  • 如何通過Python的pyttsx3庫(kù)將文字轉(zhuǎn)為音頻

    如何通過Python的pyttsx3庫(kù)將文字轉(zhuǎn)為音頻

    pyttsx3是一個(gè)開源的Python文本轉(zhuǎn)語音庫(kù),可以將文本轉(zhuǎn)換為自然的人類語音,這篇文章主要介紹了如何通過Python的pyttsx3庫(kù)將文字轉(zhuǎn)為音頻,需要的朋友可以參考下
    2023-04-04
  • Python曲線擬合多項(xiàng)式深入詳解

    Python曲線擬合多項(xiàng)式深入詳解

    這篇文章主要給大家介紹了關(guān)于Python使用scipy進(jìn)行曲線擬合的相關(guān)資料,Scipy優(yōu)化和擬合采用的是optimize模塊,該模塊提供了函數(shù)最小值(標(biāo)量或多維)、曲線擬合和尋找等式的根的有用算法,需要的朋友可以參考下
    2022-11-11
  • Python調(diào)用ChatGPT制作基于Tkinter的桌面時(shí)鐘

    Python調(diào)用ChatGPT制作基于Tkinter的桌面時(shí)鐘

    這篇文章主要為大家詳細(xì)介紹了Python如何調(diào)用ChatGPT制作基于Tkinter的桌面時(shí)鐘,文中的示例代碼講解詳細(xì),感興趣的可以了解一下
    2023-03-03
  • 最新評(píng)論