解決Keras的自定義lambda層去reshape張量時(shí)model保存出錯(cuò)問(wèn)題
前幾天忙著參加一個(gè)AI Challenger比賽,一直沒(méi)有更新博客,忙了將近一個(gè)月的時(shí)間,也沒(méi)有取得很好的成績(jī),不過(guò)這這段時(shí)間內(nèi)的確學(xué)到了很多,就在決賽結(jié)束的前一天晚上,準(zhǔn)備復(fù)現(xiàn)使用一個(gè)新的網(wǎng)絡(luò)UPerNet的時(shí)候出現(xiàn)了一個(gè)很匪夷所思,莫名其妙的一個(gè)問(wèn)題。谷歌很久都沒(méi)有解決,最后在一個(gè)日語(yǔ)網(wǎng)站上看到了解決方法。
事后想想,這個(gè)問(wèn)題在后面搭建網(wǎng)絡(luò)的時(shí)候會(huì)很常見(jiàn),但是網(wǎng)上卻沒(méi)有人提出解決辦法,So, I think that's very necessary for me to note this.
背景
分割網(wǎng)絡(luò)在進(jìn)行上采樣的時(shí)候我用的是雙線性插值上采樣的,而Keras里面并沒(méi)有實(shí)現(xiàn)雙線性插值的函數(shù),所以要自己調(diào)用tensorflow里面的tf.image.resize_bilinear()函數(shù)來(lái)進(jìn)行resize,如果直接用tf.image.resize_bilinear()函數(shù)對(duì)Keras張量進(jìn)行resize的話,會(huì)報(bào)出異常,大概意思是tenorflow張量不能轉(zhuǎn)換為Keras張量,要想將Kears Tensor轉(zhuǎn)換為 Tensorflow Tensor需要進(jìn)行自定義層,Keras自定義層的時(shí)候需要用到Lambda層來(lái)包裝。
大概源碼(只是大概意思)如下:
from keras.layers import Lambda import tensorflow as tf first_layer=Input(batch_shape=(None, 64, 32, 3)) f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x) upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3])) f=upsample_bilinear(f)
然后編譯 這個(gè)源碼:
optimizer = SGD(lr=0.01, momentum=0.9) model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy']) model.save('model.hdf5')
其中要注意到這個(gè)tf.image.resize_bilinear()里面的size,我用的是根據(jù)張量(first_layer)的形狀來(lái)做為reshape后的形狀,保存模型用的是model.save().然后就會(huì)出現(xiàn)以下錯(cuò)誤!
異常描述:
在一個(gè)epoch完成后保存model時(shí)出現(xiàn)下面錯(cuò)誤,五個(gè)錯(cuò)誤提示隨機(jī)出現(xiàn):
TypeError: cannot serialize ‘_io.TextIOWrapper' object
TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()
AttributeError: ‘NoneType' object has no attribute ‘update'
TypeError: cannot deepcopy this pattern object
TypeError: can't pickle module objects
問(wèn)題分析:
這個(gè)有兩方面原因:
tf.image.resize_bilinear()中的size不應(yīng)該用另一個(gè)張量的size去指定。
如果用了另一個(gè)張量去指定size,用model.save()來(lái)保存model是不能序列化的。那么保存model的時(shí)候只能保存權(quán)重——model.save_weights('mode_weights.hdf5')
解決辦法(兩種):
1.tf.image.resize_bilinear()的size用常數(shù)去指定
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))
2.如果用了另一個(gè)張量去指定size,那么就修改保存模型的函數(shù),變成只保存權(quán)重
model.save_weights('model_weights.hdf5')
總結(jié):
我想使用keras的Lambda層去reshape一個(gè)張量
如果為重塑形狀指定了張量,則保存模型(保存)將失敗
您可以使用save_weights而不是save進(jìn)行保存
補(bǔ)充知識(shí):Keras 添加一個(gè)自定義的loss層(output及compile中,輸出及l(fā)oss的表示方法)
例如:
計(jì)算兩個(gè)層之間的距離,作為一個(gè)loss
distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))
這是添加的一個(gè)loss層,這個(gè)distance就直接作為loss
model=Model(input=[,,,], output=[distance])
model.compile(....., loss=lambda y_true, y_pred: ypred)
以上這篇解決Keras的自定義lambda層去reshape張量時(shí)model保存出錯(cuò)問(wèn)題就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python數(shù)據(jù)清洗&預(yù)處理入門(mén)教程
凡事預(yù)則立,不預(yù)則廢,訓(xùn)練機(jī)器學(xué)習(xí)模型也是如此。數(shù)據(jù)清洗和預(yù)處理是模型訓(xùn)練之前的必要過(guò)程,否則模型可能就廢了。本文是一個(gè)初學(xué)者指南,將帶你領(lǐng)略如何在任意的數(shù)據(jù)集上,針對(duì)任意一個(gè)機(jī)器學(xué)習(xí)模型,完成數(shù)據(jù)預(yù)處理工作2022-10-10tensorflow實(shí)現(xiàn)簡(jiǎn)單的卷積網(wǎng)絡(luò)
這篇文章主要為大家詳細(xì)介紹了tensorflow實(shí)現(xiàn)簡(jiǎn)單的卷積網(wǎng)絡(luò),使用的數(shù)據(jù)集是MNIST,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-05-05使用Python pandas讀取CSV文件應(yīng)該注意什么?
本文是給使用pandas的新手而寫(xiě),主要列出一些常見(jiàn)的問(wèn)題,根據(jù)筆者所踩過(guò)的坑,進(jìn)行歸納總結(jié),希望對(duì)讀者有所幫助,需要的朋友可以參考下2021-06-06Python實(shí)現(xiàn)亂序文件重新命名編號(hào)
這篇文章主要為大家詳細(xì)介紹一下Python的一個(gè)神操作,那就是實(shí)現(xiàn)亂序文件重新命名編號(hào)功能,文中的示例代碼講解詳細(xì),感興趣的可以嘗試一下2022-08-08Python Sqlalchemy如何實(shí)現(xiàn)select for update
這篇文章主要介紹了Python Sqlalchemy如何實(shí)現(xiàn)select for update,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10對(duì)python實(shí)現(xiàn)合并兩個(gè)排序鏈表的方法詳解
今天小編就為大家分享一篇對(duì)python實(shí)現(xiàn)合并兩個(gè)排序鏈表的方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01Python實(shí)現(xiàn)對(duì)特定列表進(jìn)行從小到大排序操作示例
這篇文章主要介紹了Python實(shí)現(xiàn)對(duì)特定列表進(jìn)行從小到大排序操作,涉及Python文件讀取、計(jì)算、正則匹配、排序等相關(guān)操作技巧,需要的朋友可以參考下2019-02-02python詳解如何通過(guò)sshtunnel pymssql實(shí)現(xiàn)遠(yuǎn)程連接數(shù)據(jù)庫(kù)
為了安全起見(jiàn),很多公司服務(wù)器數(shù)據(jù)庫(kù)的訪問(wèn)多半是要做限制的,由專門(mén)的DBA管理,而且都是做的集群,數(shù)據(jù)庫(kù)只能內(nèi)網(wǎng)訪問(wèn),所以就有一個(gè)直接的問(wèn)題是,往往多數(shù)時(shí)候,在別的機(jī)器上(比如自己本地),是不能訪問(wèn)數(shù)據(jù)庫(kù)的,給日常開(kāi)發(fā)調(diào)試造成了很大不便2021-10-10Django實(shí)現(xiàn)自定義404,500頁(yè)面教程
這篇文章主要介紹了Django實(shí)現(xiàn)自定義404,500頁(yè)面的詳細(xì)方法,非常簡(jiǎn)單實(shí)用,有需要的小伙伴可以參考下2017-03-03