keras-siamese用自己的數據集實現詳解
Siamese網絡不做過多介紹,思想并不難:輸入兩張圖像,輸出這兩張圖像的相似度;兩個輸入分支的網絡結構是相同的,參數共享。
主要是發現很多代碼都是基于mnist數據集的,下面說一下怎么用自己的數據集實現siamese網絡。
首先,先整理數據集,相同的類放到同一個文件夾下,如下圖所示:
接下來,將pairs及對應的label寫到csv中,代碼如下:
"""Build the image-pair CSV used to train a Siamese network.

The dataset root holds one sub-directory per class.  Every row written to
the CSV is ``first_image_path, second_image_path, label`` where the label is
1 for two images of the same class and 0 for images of different classes.
"""
import csv
import itertools
import os
import random

# Directory that contains one sub-directory per class.
path = '/Users/mac/Desktop/wxd/flag/category/'
# CSV file the pairs are written to.
csv_file = '/Users/mac/Desktop/wxd/flag/data.csv'


def _visible_entries(directory):
    """Return the sorted names in *directory*, skipping hidden entries.

    Hidden entries (names starting with '.', e.g. macOS ``.DS_Store``) are
    not images and must not end up in a pair.  Sorting makes the generated
    CSV independent of the arbitrary order ``os.listdir`` returns.
    """
    return sorted(name for name in os.listdir(directory)
                  if not name.startswith('.'))


def list_class_dirs(root):
    """Return the paths of all non-hidden class sub-directories of *root*."""
    candidates = (os.path.join(root, name) for name in _visible_entries(root))
    # A stray regular file in the root would make os.listdir() fail later.
    return [candidate for candidate in candidates if os.path.isdir(candidate)]


def _subsample(items, limit):
    """Return *items* unchanged, or a random sample of at most *limit* of them."""
    if limit is None or limit >= len(items):
        return items
    return random.sample(items, limit)


def iter_pairs(root, max_per_class=None):
    """Yield ``[path_a, path_b, label]`` rows for the dataset under *root*.

    All same-class pairs (label 1) are yielded first, one class after the
    other, followed by the different-class pairs (label 0) for every
    combination of two classes.

    :param root: directory containing one sub-directory per class
    :param max_per_class: optional cap on how many randomly chosen images of
        each class take part in the different-class pairs of one class
        combination; ``None`` (the default) uses every image
    """
    class_dirs = list_class_dirs(root)
    images = {
        class_dir: [os.path.join(class_dir, name)
                    for name in _visible_entries(class_dir)]
        for class_dir in class_dirs
    }

    # Same-class pairs: every unordered pair of images inside one class.
    for class_dir in class_dirs:
        for first, second in itertools.combinations(images[class_dir], 2):
            yield [first, second, 1]

    # Different-class pairs: the cross product of two distinct classes.
    for dir_a, dir_b in itertools.combinations(class_dirs, 2):
        images_a = _subsample(images[dir_a], max_per_class)
        images_b = _subsample(images[dir_b], max_per_class)
        for first, second in itertools.product(images_a, images_b):
            yield [first, second, 0]


def write_pairs_csv(root, csv_path, max_per_class=None):
    """Append all pair rows for the dataset under *root* to *csv_path*.

    The file is opened in append mode, so rows from an earlier run are kept;
    delete the file first when regenerating the list from scratch.
    """
    # newline='' is what the csv module requires; without it every row is
    # followed by a blank line on Windows.
    with open(csv_path, 'a', newline='') as f:
        csv.writer(f).writerows(iter_pairs(root, max_per_class))


if __name__ == '__main__':
    write_pairs_csv(path, csv_file)
相當于csv每一行都包含一對結果,每一行有三列:第一列是第一張圖片的路徑,第二列是第二張圖片的路徑,第三列是兩張圖片是否同類的label,屬于同一個類的label為1,不同類的為0,可參考下圖:
然后,由于keras的fit函數需要將訓練數據都塞入內存,而大部分訓練數據都較大,因此採用fit_generator生成器的方法,便可以訓練大數據,代碼如下:
from __future__ import absolute_import from __future__ import print_function import numpy as np from keras.models import Model from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \ Activation, ZeroPadding2D from keras.layers import add, Flatten from keras.utils import plot_model from keras.metrics import top_k_categorical_accuracy from keras.preprocessing.image import ImageDataGenerator from keras.models import load_model import tensorflow as tf import random import os import cv2 import csv import numpy as np from keras.models import Model from keras.layers import Input, Flatten, Dense, Dropout, Lambda from keras.optimizers import RMSprop from keras import backend as K from keras.callbacks import ModelCheckpoint from keras.preprocessing.image import img_to_array """ 自定義的參數(shù) """ im_width = 224 im_height = 224 epochs = 100 batch_size = 64 iterations = 1000 csv_path = '' model_result = '' # 計算歐式距離 def euclidean_distance(vects): x, y = vects sum_square = K.sum(K.square(x - y), axis=1, keepdims=True) return K.sqrt(K.maximum(sum_square, K.epsilon())) def eucl_dist_output_shape(shapes): shape1, shape2 = shapes return (shape1[0], 1) # 計算loss def contrastive_loss(y_true, y_pred): '''Contrastive loss from Hadsell-et-al.'06 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf ''' margin = 1 square_pred = K.square(y_pred) margin_square = K.square(K.maximum(margin - y_pred, 0)) return K.mean(y_true * square_pred + (1 - y_true) * margin_square) def compute_accuracy(y_true, y_pred): '''計算準確率 ''' pred = y_pred.ravel() < 0.5 print('pred:', pred) return np.mean(pred == y_true) def accuracy(y_true, y_pred): '''Compute classification accuracy with a fixed threshold on distances. 
''' return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype))) def processImg(filename): """ :param filename: 圖像的路徑 :return: 返回的是歸一化矩陣 """ img = cv2.imread(filename) img = cv2.resize(img, (im_width, im_height)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = img_to_array(img) img /= 255 return img def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None): if name is not None: bn_name = name + '_bn' conv_name = name + '_conv' else: bn_name = None conv_name = None x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x) x = BatchNormalization(axis=3, name=bn_name)(x) return x def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False): k1, k2, k3 = nb_filters x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same') x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same') x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same') if with_conv_shortcut: shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1) x = add([x, shortcut]) return x else: x = add([x, inpt]) return x def resnet_50(): width = im_width height = im_height channel = 3 inpt = Input(shape=(width, height, channel)) x = ZeroPadding2D((3, 3))(inpt) x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid') x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x) # conv2_x x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True) x = bottleneck_Block(x, nb_filters=[64, 64, 256]) x = bottleneck_Block(x, nb_filters=[64, 64, 256]) # conv3_x x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True) x = bottleneck_Block(x, nb_filters=[128, 128, 512]) x = bottleneck_Block(x, nb_filters=[128, 128, 512]) x = bottleneck_Block(x, nb_filters=[128, 128, 512]) # conv4_x x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True) 
x = bottleneck_Block(x, nb_filters=[256, 256, 1024]) x = bottleneck_Block(x, nb_filters=[256, 256, 1024]) x = bottleneck_Block(x, nb_filters=[256, 256, 1024]) x = bottleneck_Block(x, nb_filters=[256, 256, 1024]) x = bottleneck_Block(x, nb_filters=[256, 256, 1024]) # conv5_x x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True) x = bottleneck_Block(x, nb_filters=[512, 512, 2048]) x = bottleneck_Block(x, nb_filters=[512, 512, 2048]) x = AveragePooling2D(pool_size=(7, 7))(x) x = Flatten()(x) x = Dense(128, activation='relu')(x) return Model(inpt, x) def generator(imgs, batch_size): """ 自定義迭代器 :param imgs: 列表,每個包含一對矩陣以及l(fā)abel :param batch_size: :return: """ while 1: random.shuffle(imgs) li = imgs[:batch_size] pairs = [] labels = [] for i in li: img1 = i[0] img2 = i[1] im1 = cv2.imread(img1) im2 = cv2.imread(img2) if im1 is None or im2 is None: continue label = int(i[2]) img1 = processImg(img1) img2 = processImg(img2) pairs.append([img1, img2]) labels.append(label) pairs = np.array(pairs) labels = np.array(labels) yield [pairs[:, 0], pairs[:, 1]], labels input_shape = (im_width, im_height, 3) base_network = resnet_50() input_a = Input(shape=input_shape) input_b = Input(shape=input_shape) # because we re-use the same instance `base_network`, # the weights of the network # will be shared across the two branches processed_a = base_network(input_a) processed_b = base_network(input_b) distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([processed_a, processed_b]) with tf.device("/gpu:0"): model = Model([input_a, input_b], distance) # train rms = RMSprop() rows = csv.reader(open(csv_path, 'r'), delimiter=',') imgs = list(rows) checkpoint = ModelCheckpoint(filepath=model_result+'flag_{epoch:03d}.h5', verbose=1) model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy]) model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])
用了回調函數保存了每一個epoch后的模型,也可以只保存最好的,之后需要對模型進行測試。
測試時直接用load_model會報錯(因為模型使用了自定義的contrastive_loss),而應該變成如下形式調用:
# The model was compiled with the custom `contrastive_loss` function, so that
# function has to be handed to `load_model` through `custom_objects`.
# NOTE(review): the custom `accuracy` metric used at compile time is not
# registered here — confirm the loaded model reports the metric you expect.
model = load_model(model_path,custom_objects={'contrastive_loss': contrastive_loss }) # model_path: file name of your own saved .h5 model
emmm,到這里,就成功訓練測試完了~~~寫得比較粗,因為這個代碼在官方給的mnist上的改動不大,只是方便大家用自己的數據集,大家如果有更好的方法可以提出意見~~~希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python實現操縱控制windows注冊表的方法分析
這篇文章主要介紹了Python實現操縱控制windows注冊表的方法,結合實例形式分析了Python使用_winreg模塊以及win32api模塊針對Windows注冊表操作相關實現技巧,需要的朋友可以參考下
2019-05-05
Python3交互式shell ipython3安裝及使用詳解
這篇文章主要介紹了Python3交互式shell ipython3安裝及使用詳解,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
2020-07-07
Python matplotlib繪圖時使用鼠標滾輪放大/縮小圖像
Matplotlib是Python程序員可用的事實上的繪圖庫,雖然它比交互式繪圖庫在圖形上更簡單,但它仍然可以是一個強大的工具,下面這篇文章主要給大家介紹了關于Python matplotlib繪圖時使用鼠標滾輪放大/縮小圖像的相關資料,需要的朋友可以參考下
2022-05-05