欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

keras的siamese(孿生網(wǎng)絡(luò))實(shí)現(xiàn)案例

 更新時(shí)間:2020年06月12日 14:20:25   作者:李上花開  
這篇文章主要介紹了keras的siamese(孿生網(wǎng)絡(luò))實(shí)現(xiàn)案例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧

代碼位于keras的官方樣例,并做了微量修改和大量學(xué)習(xí)?。

最終效果:

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上這篇keras的siamese(孿生網(wǎng)絡(luò))實(shí)現(xiàn)案例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python字節(jié)串類型bytes及用法

    Python字節(jié)串類型bytes及用法

    這篇文章介紹了Python字節(jié)串類型bytes及用法,文中通過(guò)示例代碼介紹的非常詳細(xì)。對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-05-05
  • Python實(shí)現(xiàn)單項(xiàng)鏈表的最全教程

    Python實(shí)現(xiàn)單項(xiàng)鏈表的最全教程

    單向鏈表也叫單鏈表,是鏈表中最簡(jiǎn)單的一種形式,它的每個(gè)節(jié)點(diǎn)包含兩個(gè)域,一個(gè)信息域(元素域)和一個(gè)鏈接域,這個(gè)鏈接指向鏈表中的下一個(gè)節(jié)點(diǎn),而最后一個(gè)節(jié)點(diǎn)的鏈接域則指向一個(gè)空值,這篇文章主要介紹了Python實(shí)現(xiàn)單項(xiàng)鏈表,需要的朋友可以參考下
    2023-01-01
  • Keras使用ImageNet上預(yù)訓(xùn)練的模型方式

    Keras使用ImageNet上預(yù)訓(xùn)練的模型方式

    這篇文章主要介紹了Keras使用ImageNet上預(yù)訓(xùn)練的模型方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2020-05-05
  • Pytorch中的torch.where函數(shù)使用

    Pytorch中的torch.where函數(shù)使用

    這篇文章主要介紹了Pytorch中的torch.where函數(shù)使用方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2024-02-02
  • Python編程itertools模塊處理可迭代集合相關(guān)函數(shù)

    Python編程itertools模塊處理可迭代集合相關(guān)函數(shù)

    本篇博客將為你介紹Python函數(shù)式編程itertools模塊中處理可迭代集合的相關(guān)函數(shù),有需要的朋友可以借鑒參考下,希望可以有所幫助
    2021-09-09
  • python制作一個(gè)桌面便簽軟件

    python制作一個(gè)桌面便簽軟件

    這篇文章主要介紹了python制作一個(gè)桌面便簽軟件分別給大家附上ubuntu和windows版的程序及源碼,有需要的小伙伴可以參考下。
    2015-08-08
  • darknet框架中YOLOv3對(duì)數(shù)據(jù)集進(jìn)行訓(xùn)練和預(yù)測(cè)詳解

    darknet框架中YOLOv3對(duì)數(shù)據(jù)集進(jìn)行訓(xùn)練和預(yù)測(cè)詳解

    這篇文章主要為大家介紹了darknet框架中YOLOv3對(duì)數(shù)據(jù)集進(jìn)行訓(xùn)練和預(yù)測(cè)使用詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-11-11
  • python中通過(guò)Django捕獲所有異常的處理

    python中通過(guò)Django捕獲所有異常的處理

    誠(chéng)然,每個(gè)人都會(huì)寫bug,程序拋異常是一件很正常的事;既然異常總是會(huì)拋,那就想辦法在拋出后,盡早解決才是王道。不能老是等待用戶反饋異常和問題,萬(wàn)一用戶懶得反饋了,豈不很尷尬
    2021-09-09
  • Python數(shù)據(jù)可視化常用4大繪圖庫(kù)原理詳解

    Python數(shù)據(jù)可視化常用4大繪圖庫(kù)原理詳解

    這篇文章主要介紹了Python數(shù)據(jù)可視化常用4大繪圖庫(kù)原理詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-10-10
  • set在python里的含義和用法

    set在python里的含義和用法

    在本篇內(nèi)容中我們給大家整理了關(guān)于set在python里的用法含義等相關(guān)知識(shí)點(diǎn)內(nèi)容,有興趣的朋友們可以學(xué)習(xí)下。
    2019-06-06

最新評(píng)論