深度學(xué)習(xí)小工程練習(xí)之tensorflow垃圾分類詳解
介紹
這是一個(gè)基于深度學(xué)習(xí)的垃圾分類小工程,用深度殘差網(wǎng)絡(luò)構(gòu)建
軟件架構(gòu)
- 使用深度殘差網(wǎng)絡(luò)resnet50作為基石,在后續(xù)添加需要的層以適應(yīng)不同的分類任務(wù)
- 模型的訓(xùn)練需要用生成器將數(shù)據(jù)集循環(huán)寫入內(nèi)存,同時(shí)圖像增強(qiáng)以泛化模型
- 使用不包含網(wǎng)絡(luò)輸出部分的resnet50權(quán)重文件進(jìn)行遷移學(xué)習(xí),只訓(xùn)練我們?cè)?個(gè)stage后增加的層
安裝教程
- 需要的第三方庫(kù)主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
- 安裝方式很簡(jiǎn)單,打開terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
- 數(shù)據(jù)集與權(quán)重文件比較大,所以沒(méi)有上傳
- 如果環(huán)境配置方面有問(wèn)題或者需要數(shù)據(jù)集與模型權(quán)重文件,可以在評(píng)論區(qū)說(shuō)明您的問(wèn)題,我將遠(yuǎn)程幫助您
使用說(shuō)明
- 文件夾theory記錄了我在本次深度學(xué)習(xí)中收獲的筆記,與模型訓(xùn)練的控制臺(tái)打印信息
- 遷移學(xué)習(xí)需要的初始權(quán)重與模型定義文件resnet50.py放在model
- 下訓(xùn)練運(yùn)行trainNet.py,訓(xùn)練結(jié)束會(huì)創(chuàng)建models文件夾,并將結(jié)果權(quán)重garclass.h5寫入該文件夾
- datagen文件夾下的genit.py用于進(jìn)行圖像預(yù)處理以及數(shù)據(jù)生成器接口
- 使用訓(xùn)練好的模型進(jìn)行垃圾分類,運(yùn)行Demo.py
結(jié)果演示
cans易拉罐
代碼解釋
在實(shí)際的模型中,我們只使用了resnet50的5個(gè)stage,后面的輸出部分需要我們自己定制,網(wǎng)絡(luò)的結(jié)構(gòu)圖如下:
stage5后我們的定制網(wǎng)絡(luò)如下:
"""定制resnet后面的層""" def custom(input_size,num_classes,pretrain): # 引入初始化resnet50模型 base_model = ResNet50(weights=pretrain, include_top=False, pooling=None, input_shape=(input_size,input_size, 3), classes=num_classes) #由于有預(yù)權(quán)重,前部分凍結(jié),后面進(jìn)行遷移學(xué)習(xí) for layer in base_model.layers: layer.trainable = False #添加后面的層 x = base_model.output x = layers.GlobalAveragePooling2D(name='avg_pool')(x) x = layers.Dropout(0.5,name='dropout1')(x) #regularizers正則化層,正則化器允許在優(yōu)化過(guò)程中對(duì)層的參數(shù)或?qū)拥募せ钋闆r進(jìn)行懲罰 #對(duì)損失函數(shù)進(jìn)行最小化的同時(shí),也需要讓對(duì)參數(shù)添加限制,這個(gè)限制也就是正則化懲罰項(xiàng),使用l2范數(shù) x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x) x = layers.BatchNormalization(name='bn_fc_01')(x) x = layers.Dropout(0.5,name='dropout2')(x) #40個(gè)分類 x = layers.Dense(num_classes,activation='softmax')(x) model = Model(inputs=base_model.input,outputs=x) #模型編譯 model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy']) return model
網(wǎng)絡(luò)的訓(xùn)練是遷移學(xué)習(xí)過(guò)程,使用已有的初始resnet50權(quán)重(5個(gè)stage已經(jīng)訓(xùn)練過(guò),卷積層已經(jīng)能夠提取特征),我們只訓(xùn)練后面的全連接層部分,4個(gè)epoch后再對(duì)較后面的層進(jìn)行訓(xùn)練微調(diào)一下,獲得更高準(zhǔn)確率,訓(xùn)練過(guò)程如下:
class Net(): def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain): self.img_size=img_size self.gar_num=gar_num self.data_dir=data_dir self.batch_size=batch_size self.pretrain=pretrain def build_train(self): """遷移學(xué)習(xí)""" model = resnet.custom(self.img_size, self.gar_num, self.pretrain) model.summary() train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size) epochs=4 model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence, max_queue_size=10,shuffle=True) #微調(diào),在實(shí)際工程中,激活函數(shù)也被算進(jìn)層里,所以總共181層,微調(diào)是為了重新訓(xùn)練部分卷積層,同時(shí)訓(xùn)練最后的全連接層 layers=149 learning_rate=1e-4 for layer in model.layers[:layers]: layer.trainable = False for layer in model.layers[layers:]: layer.trainable = True Adam =adam(lr=learning_rate, decay=0.0005) model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy']) model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1, callbacks=[ callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'), callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'), callbacks.EarlyStopping(monitor='val_loss', patience=10),], validation_data=validation_sequence,max_queue_size=10,shuffle=True) print('finish train,look for garclass.h5')
訓(xùn)練結(jié)果如下:
"""
loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
訓(xùn)練用了9小時(shí)左右
"""
如果使用更好的顯卡,可以更快完成訓(xùn)練
最后
希望大家可以體驗(yàn)到深度學(xué)習(xí)帶來(lái)的收獲,能和大家學(xué)習(xí)很開心,更多關(guān)于深度學(xué)習(xí)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python爬取百度翻譯實(shí)現(xiàn)中英互譯功能
這篇文章主要介紹了利用Python爬蟲爬取百度翻譯,從而實(shí)現(xiàn)中英文互譯的功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2022-01-01Python基于stuck實(shí)現(xiàn)scoket文件傳輸
這篇文章主要介紹了Python基于stuck實(shí)現(xiàn)scoket文件傳輸,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04簡(jiǎn)單學(xué)習(xí)Python time模塊
這篇文章主要和大家一起簡(jiǎn)單學(xué)習(xí)一下Python time模塊,Python time模塊提供了一些用于管理時(shí)間和日期的C庫(kù)函數(shù),對(duì)time模塊感興趣的小伙伴們可以參考一下2016-04-04