欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

keras-siamese用自己的數(shù)據(jù)集實現(xiàn)詳解

 更新時間:2020年06月10日 09:15:29   作者:莫離已成歌  
這篇文章主要介紹了keras-siamese用自己的數(shù)據(jù)集實現(xiàn)詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

Siamese網(wǎng)絡(luò)不做過多介紹,思想并不難,輸入兩個圖像,輸出這兩張圖像的相似度,兩個輸入的網(wǎng)絡(luò)結(jié)構(gòu)是相同的,參數(shù)共享。

主要發(fā)現(xiàn)很多代碼都是基于mnist數(shù)據(jù)集的,下面說一下怎么用自己的數(shù)據(jù)集實現(xiàn)siamese網(wǎng)絡(luò)。

首先,先整理數(shù)據(jù)集,相同的類放到同一個文件夾下,如下圖所示:

接下來,將pairs及對應(yīng)的label寫到csv中,代碼如下:

import os
import random
import csv
#圖片所在的路徑
path = '/Users/mac/Desktop/wxd/flag/category/'
#files列表保存所有類別的路徑
files=[]
same_pairs=[]
different_pairs=[]
for file in os.listdir(path):
 if file[0]=='.':
  continue
 file_path = os.path.join(path,file)
 files.append(file_path)
#該地址為csv要保存到的路徑,a表示追加寫入
with open('/Users/mac/Desktop/wxd/flag/data.csv','a') as f:
 #保存相同對
 writer = csv.writer(f)
 for file in files:
  imgs = os.listdir(file) 
  for i in range(0,len(imgs)-1):
   for j in range(i+1,len(imgs)):
    pairs = []
    name = file.split(sep='/')[-1]
    pairs.append(path+name+'/'+imgs[i])
    pairs.append(path+name+'/'+imgs[j])
    pairs.append(1)
    writer.writerow(pairs)
 #保存不同對
 for i in range(0,len(files)-1):
  for j in range(i+1,len(files)):
   filea = files[i]
   fileb = files[j]
   imga_li = os.listdir(filea)
   imgb_li = os.listdir(fileb)
   random.shuffle(imga_li)
   random.shuffle(imgb_li)
   a_li = imga_li[:]
   b_li = imgb_li[:]
   for p in range(len(a_li)):
    for q in range(len(b_li)):
     pairs = []
     name1 = filea.split(sep='/')[-1]
     name2 = fileb.split(sep='/')[-1]
     pairs.append(path+name1+'/'+a_li[p])
     pairs.append(path+name2+'/'+b_li[q])
     pairs.append(0)
     writer.writerow(pairs)

相當(dāng)于csv每一行都包含一對結(jié)果,每一行有三列,第一列第一張圖片路徑,第二列第二張圖片路徑,第三列是不是相同的label,屬于同一個類的label為1,不同類的為0,可參考下圖:

然后,由于keras的fit函數(shù)需要將訓(xùn)練數(shù)據(jù)都塞入內(nèi)存,而大部分訓(xùn)練數(shù)據(jù)都較大,因此才用fit_generator生成器的方法,便可以訓(xùn)練大數(shù)據(jù),代碼如下:

from __future__ import absolute_import
from __future__ import print_function
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
 Activation, ZeroPadding2D
from keras.layers import add, Flatten
from keras.utils import plot_model
from keras.metrics import top_k_categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
import tensorflow as tf
import random
import os
import cv2
import csv
import numpy as np
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import img_to_array
 
"""
自定義的參數(shù)
"""
im_width = 224
im_height = 224
epochs = 100
batch_size = 64
iterations = 1000
csv_path = ''
model_result = ''
 
 
# 計算歐式距離
def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))
 
def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)
 
# 計算loss
def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 square_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
 
def compute_accuracy(y_true, y_pred):
 '''計算準確率
 '''
 pred = y_pred.ravel() < 0.5
 print('pred:', pred)
 return np.mean(pred == y_true)
 
def accuracy(y_true, y_pred):
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
 
def processImg(filename):
 """
 :param filename: 圖像的路徑
 :return: 返回的是歸一化矩陣
 """
 img = cv2.imread(filename)
 img = cv2.resize(img, (im_width, im_height))
 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 img = img_to_array(img)
 img /= 255
 return img
 
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None):
 if name is not None:
  bn_name = name + '_bn'
  conv_name = name + '_conv'
 else:
  bn_name = None
  conv_name = None
 
 x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x)
 x = BatchNormalization(axis=3, name=bn_name)(x)
 return x
 
def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False):
 k1, k2, k3 = nb_filters
 x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same')
 x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same')
 x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same')
 if with_conv_shortcut:
  shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1)
  x = add([x, shortcut])
  return x
 else:
  x = add([x, inpt])
  return x
 
def resnet_50():
 width = im_width
 height = im_height
 channel = 3
 inpt = Input(shape=(width, height, channel))
 x = ZeroPadding2D((3, 3))(inpt)
 x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
 x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
 
 # conv2_x
 x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 
 # conv3_x
 x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 
 # conv4_x
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 
 # conv5_x
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 
 x = AveragePooling2D(pool_size=(7, 7))(x)
 x = Flatten()(x)
 x = Dense(128, activation='relu')(x)
 return Model(inpt, x)
 
def generator(imgs, batch_size):
 """
 自定義迭代器
 :param imgs: 列表,每個包含一對矩陣以及l(fā)abel
 :param batch_size:
 :return:
 """
 while 1:
  random.shuffle(imgs)
  li = imgs[:batch_size]
  pairs = []
  labels = []
  for i in li:
   img1 = i[0]
   img2 = i[1]
   im1 = cv2.imread(img1)
   im2 = cv2.imread(img2)
   if im1 is None or im2 is None:
    continue
   label = int(i[2])
   img1 = processImg(img1)
   img2 = processImg(img2)
   pairs.append([img1, img2])
   labels.append(label)
  pairs = np.array(pairs)
  labels = np.array(labels)
  yield [pairs[:, 0], pairs[:, 1]], labels
 
input_shape = (im_width, im_height, 3)
base_network = resnet_50()
 
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
 
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
 
distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])
with tf.device("/gpu:0"):
 model = Model([input_a, input_b], distance)
 # train
 rms = RMSprop()
 rows = csv.reader(open(csv_path, 'r'), delimiter=',')
 imgs = list(rows)
 checkpoint = ModelCheckpoint(filepath=model_result+'flag_{epoch:03d}.h5', verbose=1)
 model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
 model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])

用了回調(diào)函數(shù)保存了每一個epoch后的模型,也可以保存最好的,之后需要對模型進行測試。

測試時直接用load_model會報錯,而應(yīng)該變成如下形式調(diào)用:

model = load_model(model_path,custom_objects={'contrastive_loss': contrastive_loss }) #選取自己的.h模型名稱

emmm,到這里,就成功訓(xùn)練測試完了~~~寫的比較粗,因為這個代碼在官方給的mnist上的改動不大,只是方便大家用自己的數(shù)據(jù)集,大家如果有更好的方法可以提出意見~~~希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python使用thrift教程的方法示例

    python使用thrift教程的方法示例

    這篇文章主要介紹了python使用thrift教程的方法示例,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-03-03
  • Python實現(xiàn)操縱控制windows注冊表的方法分析

    Python實現(xiàn)操縱控制windows注冊表的方法分析

    這篇文章主要介紹了Python實現(xiàn)操縱控制windows注冊表的方法,結(jié)合實例形式分析了Python使用_winreg模塊以及win32api模塊針對Windows注冊表操作相關(guān)實現(xiàn)技巧,需要的朋友可以參考下
    2019-05-05
  • Python中三種花式打印的示例詳解

    Python中三種花式打印的示例詳解

    在Python中有很多好玩的花式打印,我們今天就來挑戰(zhàn)下面三個常見的花式打印。文中的示例代碼講解詳細,感興趣的小伙伴快跟隨小編一起學(xué)習(xí)一下吧
    2022-03-03
  • 對Python 數(shù)組的切片操作詳解

    對Python 數(shù)組的切片操作詳解

    今天小編就為大家分享一篇對Python 數(shù)組的切片操作詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-07-07
  • Python3交互式shell ipython3安裝及使用詳解

    Python3交互式shell ipython3安裝及使用詳解

    這篇文章主要介紹了Python3交互式shell ipython3安裝及使用詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2020-07-07
  • python分析網(wǎng)頁上所有超鏈接的方法

    python分析網(wǎng)頁上所有超鏈接的方法

    這篇文章主要介紹了python分析網(wǎng)頁上所有超鏈接的方法,涉及Python使用urllib模塊操作頁面超鏈接的技巧,需要的朋友可以參考下
    2015-05-05
  • 使用Python實現(xiàn)微信拍一拍功能的思路代碼

    使用Python實現(xiàn)微信拍一拍功能的思路代碼

    這篇文章主要介紹了使用Python實現(xiàn)微信“拍一拍”的思路代碼,,本文通過示例代碼給大家介紹的非常詳細,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2020-07-07
  • python常見的格式化輸出小結(jié)

    python常見的格式化輸出小結(jié)

    今天在寫代碼的時候,需要統(tǒng)一化輸出格式進行,一時想不起竟具體細節(jié),用了最笨的方法,所以覺得有必要將常見的方法進行一個總結(jié)。下面這篇文中就給大家總結(jié)了python中常見的格式化輸出,比如打印字符串、打印整數(shù)和打印浮點數(shù)等,下面來看看詳細的輸出方法吧。
    2016-12-12
  • Python?matplotlib繪圖時使用鼠標滾輪放大/縮小圖像

    Python?matplotlib繪圖時使用鼠標滾輪放大/縮小圖像

    Matplotlib是Python程序員可用的事實上的繪圖庫,雖然它比交互式繪圖庫在圖形上更簡單,但它仍然可以一個強大的工具,下面這篇文章主要給大家介紹了關(guān)于Python?matplotlib繪圖時使用鼠標滾輪放大/縮小圖像的相關(guān)資料,需要的朋友可以參考下
    2022-05-05
  • 詳解OpenCV執(zhí)行連通分量標記的方法和分析

    詳解OpenCV執(zhí)行連通分量標記的方法和分析

    在本教程中,您將學(xué)習(xí)如何使用?OpenCV?執(zhí)行連通分量標記和分析。具體來說,我們將重點介紹?OpenCV?最常用的連通分量標記函數(shù):cv2.connectedComponentsWithStats,感興趣的可以了解一下
    2022-08-08

最新評論