tensorflow 固定部分參數(shù)訓(xùn)練,只訓(xùn)練部分參數(shù)的實例
在使用tensorflow來訓(xùn)練一個模型的時候,有時候需要依靠驗證集來判斷模型是否已經(jīng)過擬合,是否需要停止訓(xùn)練。
1.首先想到的是用tf.placeholder()載入不同的數(shù)據(jù)來進(jìn)行計算,比如
def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(input_) conv2 = tf.layers.conv2d(conv1) return conv2 input_ = tf.placeholder() output = inference(input_) ... calculate_loss_op = ... train_op = ... ... with tf.Session() as sess: sess.run([loss, train_op], feed_dict={input_: train_data}) if validation == True: sess.run([loss], feed_dict={input_: validate_date})
這種方式很簡單,也很直接了然。
2.但是,如果處理的數(shù)據(jù)量很大的時候,使用 tf.placeholder() 來載入數(shù)據(jù)會嚴(yán)重地拖慢訓(xùn)練的進(jìn)度,因此,常用tfrecords文件來讀取數(shù)據(jù)。
此時,很容易想到,將不同的值傳入inference()函數(shù)中進(jìn)行計算。
train_batch, label_batch = decode_train() val_train_batch, val_label_batch = decode_validation() train_result = inference(train_batch) ... loss = .. train_op = ... ... if validation == True: val_result = inference(val_train_batch) val_loss = .. with tf.Session() as sess: sess.run([loss, train_op]) if validation == True: sess.run([val_result, val_loss])
這種方式看似能夠直接調(diào)用inference()來對驗證數(shù)據(jù)進(jìn)行前向傳播計算,但是,實則會在原圖上添加上許多新的結(jié)點(diǎn),這些結(jié)點(diǎn)的參數(shù)都是需要重新初始化的,也是就是說,驗證的時候并不是使用訓(xùn)練的權(quán)重。
3.用一個tf.placeholder來控制是否訓(xùn)練、驗證。
def inference(input_): ... ... ... return inference_result train_batch, label_batch = decode_train() val_batch, val_label = decode_validation() is_training = tf.placeholder(tf.bool, shape=()) x = tf.cond(is_training, lambda: train_batch, lambda: val_batch) y = tf.cond(is_training, lambda: train_label, lambda: val_label) logits = inference(x) loss = cal_loss(logits, y) train_op = optimize(loss) with tf.Session() as sess: loss, _ = sess.run([loss, train_op], feed_dict={is_training: True}) if validation == True: loss = sess.run(loss, feed_dict={is_training: False})
使用這種方式就可以在一個大圖里創(chuàng)建一個分支條件,從而通過控制placeholder來控制是否進(jìn)行驗證。
以上這篇tensorflow 固定部分參數(shù)訓(xùn)練,只訓(xùn)練部分參數(shù)的實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python使用open函數(shù)對文件進(jìn)行處理詳解
今天看了open函數(shù),看到w+ r+ a+ 這種可讀可寫的操作,下面這篇文章主要給大家介紹了關(guān)于python使用open函數(shù)對文件進(jìn)行處理的相關(guān)資料,文中通過實例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-05-05Python3使用matplotlib繪圖時,坐標(biāo)軸刻度不從X軸、y軸兩端開始
這篇文章主要介紹了Python3使用matplotlib繪圖時,坐標(biāo)軸刻度不從X軸、y軸兩端開始問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-08-08pycharm 中mark directory as exclude的用法詳解
今天小編就為大家分享一篇pycharm 中mark directory as exclude的用法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02python神經(jīng)網(wǎng)絡(luò)特征金字塔FPN原理
這篇文章主要為大家介紹了python神經(jīng)網(wǎng)絡(luò)特征金字塔FPN原理的解釋,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05