keras自定義損失函數(shù)并且模型加載的寫法介紹
keras自定義函數(shù)時(shí)候,正常在模型里自己寫好自定義的函數(shù),然后在模型編譯的那行代碼里寫上接口即可。如下所示,focal_loss和fbeta_score是我們自己定義的兩個(gè)函數(shù),在model.compile加入它們,metrics里‘a(chǎn)ccuracy'是keras自帶的度量函數(shù)。
def focal_loss(): ... return xx def fbeta_score(): ... return yy model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )
訓(xùn)練好之后,模型加載也需要再額外加一行,通過load_model里的custom_objects將我們定義的兩個(gè)函數(shù)以字典的形式加入就能正常加載模型啦。
weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})
補(bǔ)充知識:keras如何使用自定義的loss及評價(jià)函數(shù)進(jìn)行訓(xùn)練及預(yù)測
1.有時(shí)候訓(xùn)練模型,現(xiàn)有的損失及評估函數(shù)并不足以科學(xué)的訓(xùn)練評估模型,這時(shí)候就需要自定義一些損失評估函數(shù),比如focal loss損失函數(shù)及dice評價(jià)函數(shù) for unet的訓(xùn)練。
2.在訓(xùn)練建模中導(dǎo)入自定義loss及評估函數(shù)。
#模型編譯時(shí)加入自定義loss及評估函數(shù) model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()], metrics=['accuracy',dice_coef]) #自定義loss及評估函數(shù) def binary_focal_loss(gamma=2, alpha=0.25): """ Binary form of focal loss. 適用于二分類問題的focal loss focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t) where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively. References: https://arxiv.org/pdf/1708.02002.pdf Usage: model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam) """ alpha = tf.constant(alpha, dtype=tf.float32) gamma = tf.constant(gamma, dtype=tf.float32) def binary_focal_loss_fixed(y_true, y_pred): """ y_true shape need be (None,1) y_pred need be compute after sigmoid """ y_true = tf.cast(y_true, tf.float32) alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha) p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon() focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t) return K.mean(focal_loss) return binary_focal_loss_fixed #''' #smooth 參數(shù)防止分母為0 def dice_coef(y_true, y_pred, smooth=1): intersection = K.sum(y_true * y_pred, axis=[1,2,3]) union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3]) return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)
注意在模型保存時(shí),記錄的loss函數(shù)名稱:你猜是哪個(gè)
a:binary_focal_loss()
b:binary_focal_loss_fixed
3.模型預(yù)測時(shí),也要加載自定義loss及評估函數(shù),不然會(huì)報(bào)錯(cuò)。
該告訴上面的答案了,保存在模型中l(wèi)oss的名稱為:binary_focal_loss_fixed,在模型預(yù)測時(shí),定義custom_objects字典,key一定要與保存在模型中的名稱一致,不然會(huì)找不到loss function。所以自定義函數(shù)時(shí),盡量避免使用我這種函數(shù)嵌套的方式,免得帶來一些意想不到的煩惱。
model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})
以上這篇keras自定義損失函數(shù)并且模型加載的寫法介紹就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- Keras神經(jīng)網(wǎng)絡(luò)efficientnet模型搭建yolov3目標(biāo)檢測平臺
- Keras保存模型并載入模型繼續(xù)訓(xùn)練的實(shí)現(xiàn)
- Keras預(yù)訓(xùn)練的ImageNet模型實(shí)現(xiàn)分類操作
- 淺談keras使用預(yù)訓(xùn)練模型vgg16分類,損失和準(zhǔn)確度不變
- Keras 實(shí)現(xiàn)加載預(yù)訓(xùn)練模型并凍結(jié)網(wǎng)絡(luò)的層
- python神經(jīng)網(wǎng)絡(luò)Keras?GhostNet模型的實(shí)現(xiàn)
相關(guān)文章
詳解Python最長公共子串和最長公共子序列的實(shí)現(xiàn)
這篇文章主要介紹了詳解Python最長公共子串和最長公共子序列的實(shí)現(xiàn)。小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2018-07-07Python?ArcPy實(shí)現(xiàn)批量拼接長時(shí)間序列柵格圖像
這篇文章主要介紹了如何基于Python中ArcPy模塊,對大量不同時(shí)相的柵格遙感影像按照其成像時(shí)間依次執(zhí)行批量拼接的方法,感興趣的可以了解一下2023-03-03關(guān)于TensorFlow、Keras、Python版本匹配一覽表
這篇文章主要介紹了關(guān)于TensorFlow、Keras、Python版本匹配一覽表,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-03-03Python調(diào)用騰訊API實(shí)現(xiàn)人臉身份證比對功能
這篇文章主要介紹了Python調(diào)用騰訊API進(jìn)行人臉身份證比對,簡單介紹了調(diào)用騰訊云API步驟,通過完整代碼展示與結(jié)果,需要的朋友可以參考下2022-04-04python bluetooth藍(lán)牙信息獲取藍(lán)牙設(shè)備類型的方法
這篇文章主要介紹了python bluetooth藍(lán)牙信息獲取藍(lán)牙設(shè)備類型的方法,具體轉(zhuǎn)化方法文中給大家介紹的非常詳細(xì),非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-11-11從0到1使用python開發(fā)一個(gè)半自動(dòng)答題小程序的實(shí)現(xiàn)
這篇文章主要介紹了從0到1使用python開發(fā)一個(gè)半自動(dòng)答題小程序的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-05-05python實(shí)現(xiàn)根據(jù)主機(jī)名字獲得所有ip地址的方法
這篇文章主要介紹了python實(shí)現(xiàn)根據(jù)主機(jī)名字獲得所有ip地址的方法,涉及Python解析IP地址的相關(guān)技巧,需要的朋友可以參考下2015-06-06