keras處理欠擬合和過擬合的實例講解
baseline
import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(16, activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()
baseline_history = baseline_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
小模型
small_model = keras.Sequential(
[
layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(4, activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
大模型
big_model = keras.Sequential(
[
layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(512, activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
繪圖比較上述三個模型
def plot_history(histories, key='binary_crossentropy'):
plt.figure(figsize=(16,10))
for name, history in histories:
val = plt.plot(history.epoch, history.history['val_'+key],
'--', label=name.title()+' Val')
plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
label=name.title()+' Train')
plt.xlabel('Epochs')
plt.ylabel(key.replace('_',' ').title())
plt.legend()
plt.xlim([0,max(history.epoch)])
plot_history([('baseline', baseline_history),
('small', small_history),
('big', big_history)])

三個模型在迭代過程中在訓(xùn)練集的表現(xiàn)都會越來越好,并且都會出現(xiàn)過擬合的現(xiàn)象
大模型在訓(xùn)練集上表現(xiàn)更好,過擬合的速度更快
l2正則減少過擬合
l2_model = keras.Sequential(
[
layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
plot_history([('baseline', baseline_history),
('l2', l2_history)])

可以發(fā)現(xiàn)正則化之后的模型在驗證集上的過擬合程度減少
添加dropout減少過擬合
dpt_model = keras.Sequential(
[
layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dropout(0.5),
layers.Dense(16, activation='relu'),
layers.Dropout(0.5),
layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
plot_history([('baseline', baseline_history),
('dropout', dpt_history)])

批正則化
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(784,)),
layers.BatchNormalization(),
layers.Dense(64, activation='relu'),
layers.BatchNormalization(),
layers.Dense(64, activation='relu'),
layers.BatchNormalization(),
layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
loss=keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()
總結(jié)
防止神經(jīng)網(wǎng)絡(luò)中過度擬合的最常用方法:
獲取更多訓(xùn)練數(shù)據(jù)。
減少網(wǎng)絡(luò)容量。
添加權(quán)重正規(guī)化。
添加dropout。
以上這篇keras處理欠擬合和過擬合的實例講解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django 創(chuàng)建后臺,配置sqlite3教程
今天小編就為大家分享一篇Django 創(chuàng)建后臺,配置sqlite3教程,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-11-11
python3 自動識別usb連接狀態(tài),即對usb重連的判斷方法
今天小編就為大家分享一篇python3 自動識別usb連接狀態(tài),即對usb重連的判斷方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07
python paramiko連接ssh實現(xiàn)命令
這篇文章主要為大家介紹了python paramiko連接ssh實現(xiàn)的命令詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2022-07-07

