keras的siamese(孿生網(wǎng)絡(luò))實(shí)現(xiàn)案例
更新時(shí)間:2020年06月12日 14:20:25 作者:李上花開
這篇文章主要介紹了keras的siamese(孿生網(wǎng)絡(luò))實(shí)現(xiàn)案例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
代碼位于keras的官方樣例,并做了微量修改和大量學(xué)習(xí)?。
最終效果:
import keras import numpy as np import matplotlib.pyplot as plt import random from keras.callbacks import TensorBoard from keras.datasets import mnist from keras.models import Model from keras.layers import Input, Flatten, Dense, Dropout, Lambda from keras.optimizers import RMSprop from keras import backend as K num_classes = 10 epochs = 20 def euclidean_distance(vects): x, y = vects sum_square = K.sum(K.square(x - y), axis=1, keepdims=True) return K.sqrt(K.maximum(sum_square, K.epsilon())) def eucl_dist_output_shape(shapes): shape1, shape2 = shapes return (shape1[0], 1) def contrastive_loss(y_true, y_pred): '''Contrastive loss from Hadsell-et-al.'06 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf ''' margin = 1 sqaure_pred = K.square(y_pred) margin_square = K.square(K.maximum(margin - y_pred, 0)) return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square) def create_pairs(x, digit_indices): '''Positive and negative pair creation. Alternates between positive and negative pairs. ''' pairs = [] labels = [] n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1 for d in range(num_classes): for i in range(n): z1, z2 = digit_indices[d][i], digit_indices[d][i + 1] pairs += [[x[z1], x[z2]]] inc = random.randrange(1, num_classes) dn = (d + inc) % num_classes z1, z2 = digit_indices[d][i], digit_indices[dn][i] pairs += [[x[z1], x[z2]]] labels += [1, 0] return np.array(pairs), np.array(labels) def create_base_network(input_shape): '''Base network to be shared (eq. to feature extraction). ''' input = Input(shape=input_shape) x = Flatten()(input) x = Dense(128, activation='relu')(x) x = Dropout(0.1)(x) x = Dense(128, activation='relu')(x) x = Dropout(0.1)(x) x = Dense(128, activation='relu')(x) return Model(input, x) def compute_accuracy(y_true, y_pred): # numpy上的操作 '''Compute classification accuracy with a fixed threshold on distances. ''' pred = y_pred.ravel() < 0.5 return np.mean(pred == y_true) def accuracy(y_true, y_pred): # Tensor上的操作 '''Compute classification accuracy with a fixed threshold on distances. ''' return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype))) def plot_train_history(history, train_metrics, val_metrics): plt.plot(history.history.get(train_metrics), '-o') plt.plot(history.history.get(val_metrics), '-o') plt.ylabel(train_metrics) plt.xlabel('Epochs') plt.legend(['train', 'validation']) # the data, split between train and test sets (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 input_shape = x_train.shape[1:] # create training+test positive and negative pairs digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)] tr_pairs, tr_y = create_pairs(x_train, digit_indices) digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)] te_pairs, te_y = create_pairs(x_test, digit_indices) # network definition base_network = create_base_network(input_shape) input_a = Input(shape=input_shape) input_b = Input(shape=input_shape) # because we re-use the same instance `base_network`, # the weights of the network # will be shared across the two branches processed_a = base_network(input_a) processed_b = base_network(input_b) distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([processed_a, processed_b]) model = Model([input_a, input_b], distance) keras.utils.plot_model(model,"siamModel.png",show_shapes=True) model.summary() # train rms = RMSprop() model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy]) history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y, batch_size=128, epochs=epochs,verbose=2, validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y)) plt.figure(figsize=(8, 4)) plt.subplot(1, 2, 1) plot_train_history(history, 'loss', 'val_loss') plt.subplot(1, 2, 2) plot_train_history(history, 'accuracy', 'val_accuracy') plt.show() # compute final accuracy on training and test sets y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]]) tr_acc = compute_accuracy(tr_y, y_pred) y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]]) te_acc = compute_accuracy(te_y, y_pred) print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc)) print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
以上這篇keras的siamese(孿生網(wǎng)絡(luò))實(shí)現(xiàn)案例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)單項(xiàng)鏈表的最全教程
單向鏈表也叫單鏈表,是鏈表中最簡(jiǎn)單的一種形式,它的每個(gè)節(jié)點(diǎn)包含兩個(gè)域,一個(gè)信息域(元素域)和一個(gè)鏈接域,這個(gè)鏈接指向鏈表中的下一個(gè)節(jié)點(diǎn),而最后一個(gè)節(jié)點(diǎn)的鏈接域則指向一個(gè)空值,這篇文章主要介紹了Python實(shí)現(xiàn)單項(xiàng)鏈表,需要的朋友可以參考下2023-01-01Keras使用ImageNet上預(yù)訓(xùn)練的模型方式
這篇文章主要介紹了Keras使用ImageNet上預(yù)訓(xùn)練的模型方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05Python編程itertools模塊處理可迭代集合相關(guān)函數(shù)
本篇博客將為你介紹Python函數(shù)式編程itertools模塊中處理可迭代集合的相關(guān)函數(shù),有需要的朋友可以借鑒參考下,希望可以有所幫助2021-09-09darknet框架中YOLOv3對(duì)數(shù)據(jù)集進(jìn)行訓(xùn)練和預(yù)測(cè)詳解
這篇文章主要為大家介紹了darknet框架中YOLOv3對(duì)數(shù)據(jù)集進(jìn)行訓(xùn)練和預(yù)測(cè)使用詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-11-11Python數(shù)據(jù)可視化常用4大繪圖庫(kù)原理詳解
這篇文章主要介紹了Python數(shù)據(jù)可視化常用4大繪圖庫(kù)原理詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10