keras-siamese: implementing a Siamese network with your own dataset
This post won't spend long introducing the Siamese network itself. The idea is simple: the model takes two images as input and outputs their similarity, and the two input branches have identical structure with shared weights.
Most of the example code out there is built on the mnist dataset, so here is how to implement a Siamese network with your own dataset.
First, organize the dataset so that all images of the same class go into the same folder, one subfolder per class under the category/ directory used in the code below.
Next, write the image pairs and their corresponding labels to a csv file. The code is as follows:
import os
import random
import csv

# path containing the class folders
path = '/Users/mac/Desktop/wxd/flag/category/'
# the files list stores the path of every class folder
files = []
same_pairs = []
different_pairs = []
for file in os.listdir(path):
    if file[0] == '.':
        continue
    file_path = os.path.join(path, file)
    files.append(file_path)
# destination path of the csv; 'a' opens the file in append mode
with open('/Users/mac/Desktop/wxd/flag/data.csv', 'a') as f:
    writer = csv.writer(f)
    # write positive (same-class) pairs
    for file in files:
        imgs = os.listdir(file)
        for i in range(0, len(imgs) - 1):
            for j in range(i + 1, len(imgs)):
                pairs = []
                name = file.split(sep='/')[-1]
                pairs.append(path + name + '/' + imgs[i])
                pairs.append(path + name + '/' + imgs[j])
                pairs.append(1)
                writer.writerow(pairs)
    # write negative (different-class) pairs
    for i in range(0, len(files) - 1):
        for j in range(i + 1, len(files)):
            filea = files[i]
            fileb = files[j]
            imga_li = os.listdir(filea)
            imgb_li = os.listdir(fileb)
            random.shuffle(imga_li)
            random.shuffle(imgb_li)
            a_li = imga_li[:]
            b_li = imgb_li[:]
            for p in range(len(a_li)):
                for q in range(len(b_li)):
                    pairs = []
                    name1 = filea.split(sep='/')[-1]
                    name2 = fileb.split(sep='/')[-1]
                    pairs.append(path + name1 + '/' + a_li[p])
                    pairs.append(path + name2 + '/' + b_li[q])
                    pairs.append(0)
                    writer.writerow(pairs)
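Before moving on, a quick way to sanity-check the generated file is to print the first few rows. This is a small sketch (not part of the original script) that reuses the csv path from the code above:
import csv

with open('/Users/mac/Desktop/wxd/flag/data.csv', 'r') as f:
    for row in list(csv.reader(f))[:5]:
        # each row: [path of image 1, path of image 2, label]
        print(row)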
In effect, each row of the csv holds one pair and has three columns: the first column is the path of the first image, the second column is the path of the second image, and the third column is the label indicating whether the two images match, 1 if they belong to the same class and 0 otherwise.
Next comes training. Since keras's fit function requires all training data to fit in memory, and most real training sets are too large for that, the fit_generator approach with a generator is used instead, which makes it possible to train on large datasets. The code is as follows:
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import tensorflow as tf
import random
import os
import cv2
import csv
from keras.models import Model, load_model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
    Activation, ZeroPadding2D
from keras.layers import add, Flatten, Lambda
from keras.utils import plot_model
from keras.metrics import top_k_categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator, img_to_array
from keras.optimizers import RMSprop
from keras import backend as K
from keras.callbacks import ModelCheckpoint

"""
User-defined parameters
"""
im_width = 224
im_height = 224
epochs = 100
batch_size = 64
iterations = 1000
csv_path = ''      # path of the data.csv generated above
model_result = ''  # directory where the epoch checkpoints will be saved
# Euclidean distance between the two branch outputs
def euclidean_distance(vects):
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

def eucl_dist_output_shape(shapes):
    shape1, shape2 = shapes
    return (shape1[0], 1)

# contrastive loss
def contrastive_loss(y_true, y_pred):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    '''
    margin = 1
    square_pred = K.square(y_pred)
    margin_square = K.square(K.maximum(margin - y_pred, 0))
    return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
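# Written out, the loss is mean( y * d^2 + (1 - y) * max(margin - d, 0)^2 ),
# where d = y_pred is the Euclidean distance between the two embeddings and
# y = 1 for same-class pairs, 0 otherwise: similar pairs are pulled together,
# dissimilar pairs are pushed at least `margin` apart.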
def compute_accuracy(y_true, y_pred):
    '''Compute classification accuracy on numpy arrays.
    '''
    pred = y_pred.ravel() < 0.5
    print('pred:', pred)
    return np.mean(pred == y_true)

def accuracy(y_true, y_pred):
    '''Compute classification accuracy with a fixed threshold on distances.
    '''
    return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def processImg(filename):
    """
    :param filename: path of the image
    :return: normalized image array
    """
    img = cv2.imread(filename)
    img = cv2.resize(img, (im_width, im_height))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img_to_array(img)
    img /= 255
    return img
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None):
    if name is not None:
        bn_name = name + '_bn'
        conv_name = name + '_conv'
    else:
        bn_name = None
        conv_name = None
    x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x)
    x = BatchNormalization(axis=3, name=bn_name)(x)
    return x

def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False):
    k1, k2, k3 = nb_filters
    x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same')
    x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same')
    x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same')
    if with_conv_shortcut:
        shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1)
        x = add([x, shortcut])
        return x
    else:
        x = add([x, inpt])
        return x
def resnet_50():
    width = im_width
    height = im_height
    channel = 3
    inpt = Input(shape=(width, height, channel))
    x = ZeroPadding2D((3, 3))(inpt)
    x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
    # conv2_x
    x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True)
    x = bottleneck_Block(x, nb_filters=[64, 64, 256])
    x = bottleneck_Block(x, nb_filters=[64, 64, 256])
    # conv3_x
    x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True)
    x = bottleneck_Block(x, nb_filters=[128, 128, 512])
    x = bottleneck_Block(x, nb_filters=[128, 128, 512])
    x = bottleneck_Block(x, nb_filters=[128, 128, 512])
    # conv4_x
    x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True)
    x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
    x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
    x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
    x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
    x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
    # conv5_x
    x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True)
    x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
    x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
    x = AveragePooling2D(pool_size=(7, 7))(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    return Model(inpt, x)
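# The ResNet-50 above is the shared branch: it maps each 224 x 224 x 3 image
# to a 128-dimensional embedding (the final Dense layer), and both inputs of
# the Siamese model below go through this single instance.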
def generator(imgs, batch_size):
    """
    Custom generator for fit_generator.
    :param imgs: list of csv rows, each holding a pair of image paths and a label
    :param batch_size:
    :return:
    """
    while 1:
        random.shuffle(imgs)
        li = imgs[:batch_size]
        pairs = []
        labels = []
        for i in li:
            img1 = i[0]
            img2 = i[1]
            im1 = cv2.imread(img1)
            im2 = cv2.imread(img2)
            if im1 is None or im2 is None:
                continue
            label = int(i[2])
            img1 = processImg(img1)
            img2 = processImg(img2)
            pairs.append([img1, img2])
            labels.append(label)
        pairs = np.array(pairs)
        labels = np.array(labels)
        yield [pairs[:, 0], pairs[:, 1]], labels
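# Each yielded batch is a pair of arrays of shape (batch, 224, 224, 3) plus a
# label array of shape (batch,); pairs whose images fail to load are skipped,
# so a batch can occasionally be slightly smaller than batch_size.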
input_shape = (im_width, im_height, 3)
base_network = resnet_50()
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
distance = Lambda(euclidean_distance,
                  output_shape=eucl_dist_output_shape)([processed_a, processed_b])
with tf.device("/gpu:0"):
    model = Model([input_a, input_b], distance)
    # train
    rms = RMSprop()
    rows = csv.reader(open(csv_path, 'r'), delimiter=',')
    imgs = list(rows)
    checkpoint = ModelCheckpoint(filepath=model_result + 'flag_{epoch:03d}.h5', verbose=1)
    model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
    model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])
The ModelCheckpoint callback saves the model after every epoch; you can also configure it to keep only the best one. After training, the model needs to be tested.
Calling load_model directly at test time will raise an error; the model has to be loaded as follows instead:
model = load_model(model_path, custom_objects={'contrastive_loss': contrastive_loss})  # use the file name of your own .h5 model
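Once loaded, the model can score a pair of images directly. The following is a minimal sketch rather than the original author's test script: a.jpg and b.jpg are hypothetical test images, 0.5 is the same distance threshold used by the accuracy functions above, and processImg and np come from the training script.
img_a = np.expand_dims(processImg('a.jpg'), axis=0)  # hypothetical test image
img_b = np.expand_dims(processImg('b.jpg'), axis=0)  # hypothetical test image
dist = model.predict([img_a, img_b])[0][0]  # predicted Euclidean distance between the two embeddings
print('same class' if dist < 0.5 else 'different class', dist)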
emmm, and with that, training and testing are done. This write-up is fairly rough, because the code deviates very little from the official mnist example; the point is simply to make it easy to plug in your own dataset. If you have a better approach, suggestions are welcome. I hope this gives everyone a useful reference, and thanks for supporting 脚本之家.