Keras自動下載的數(shù)據(jù)集/模型存放位置介紹
Mac
# 數(shù)據(jù)集
~/.keras/datasets/# 模型
~/.keras/models/
Linux
# 數(shù)據(jù)集
~/.keras/datasets/
Windows
# win10
C:\Users\user_name\.keras\datasets
補充知識:Keras_gan生成自己的數(shù)據(jù),并保存模型
我就廢話不多說了,大家還是直接看代碼吧~
from __future__ import print_function, division from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam import os import matplotlib.pyplot as plt import sys import numpy as np class GAN(): def __init__(self): self.img_rows = 3 self.img_cols = 60 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 optimizer = Adam(0.0002, 0.5) # 構(gòu)建和編譯判別器 self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # 構(gòu)建生成器 self.generator = self.build_generator() # 生成器輸入噪音,生成假的圖片 z = Input(shape=(self.latent_dim,)) img = self.generator(z) # 為了組合模型,只訓(xùn)練生成器 self.discriminator.trainable = False # 判別器將生成的圖像作為輸入并確定有效性 validity = self.discriminator(img) # The combined model (stacked generator and discriminator) # 訓(xùn)練生成器騙過判別器 self.combined = Model(z, validity) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) def build_generator(self): model = Sequential() model.add(Dense(64, input_dim=self.latent_dim)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(128)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) #np.prod(self.img_shape)=3x60x1 model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) model.summary() noise = Input(shape=(self.latent_dim,)) img = model(noise) #輸入噪音,輸出圖片 return Model(noise, img) def build_discriminator(self): model = Sequential() model.add(Flatten(input_shape=self.img_shape)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(128)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(64)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary() img = Input(shape=self.img_shape) validity = model(img) return Model(img, validity) def train(self, epochs, batch_size=128, sample_interval=50): ############################################################ #自己數(shù)據(jù)集此部分需要更改 # 加載數(shù)據(jù)集 data = np.load('data/相對大小分叉.npy') data = data[:,:,0:60] # 歸一化到-1到1 data = data * 2 - 1 data = np.expand_dims(data, axis=3) ############################################################ # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) for epoch in range(epochs): # --------------------- # 訓(xùn)練判別器 # --------------------- # data.shape[0]為數(shù)據(jù)集的數(shù)量,隨機生成batch_size個數(shù)量的隨機數(shù),作為數(shù)據(jù)的索引 idx = np.random.randint(0, data.shape[0], batch_size) #從數(shù)據(jù)集隨機挑選batch_size個數(shù)據(jù),作為一個批次訓(xùn)練 imgs = data[idx] #噪音維度(batch_size,100) noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # 由生成器根據(jù)噪音生成假的圖片 gen_imgs = self.generator.predict(noise) # 訓(xùn)練判別器,判別器希望真實圖片,打上標(biāo)簽1,假的圖片打上標(biāo)簽0 d_loss_real = self.discriminator.train_on_batch(imgs, valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # 訓(xùn)練生成器 # --------------------- noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # Train the generator (to have the discriminator label samples as valid) g_loss = self.combined.train_on_batch(noise, valid) # 打印loss值 print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) # 沒sample_interval個epoch保存一次生成圖片 if epoch % sample_interval == 0: self.sample_images(epoch) if not os.path.exists("keras_model"): os.makedirs("keras_model") self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True) self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True) def sample_images(self, epoch): r, c = 10, 10 # 重新生成一批噪音,維度為(100,100) noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) # 將生成的圖片重新歸整到0-1之間 gen = 0.5 * gen_imgs + 0.5 gen = gen.reshape(-1,3,60) fig,axs = plt.subplots(r,c) cnt = 0 for i in range(r): for j in range(c): xy = gen[cnt] for k in range(len(xy)): x = xy[k][0:30] y = xy[k][30:60] if k == 0: axs[i,j].plot(x,y,color='blue') if k == 1: axs[i,j].plot(x,y,color='red') if k == 2: axs[i,j].plot(x,y,color='green') plt.xlim(0.,1.) plt.ylim(0.,1.) plt.xticks(np.arange(0,1,0.1)) plt.xticks(np.arange(0,1,0.1)) axs[i,j].axis('off') cnt += 1 if not os.path.exists("keras_imgs"): os.makedirs("keras_imgs") fig.savefig("keras_imgs/%d.png" % epoch) plt.close() def test(self,gen_nums=100,save=False): self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True) self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True) noise = np.random.normal(0,1,(gen_nums,self.latent_dim)) gen = self.generator.predict(noise) gen = 0.5 * gen + 0.5 gen = gen.reshape(-1,3,60) print(gen.shape) ############################################################### #直接可視化生成圖片 if save: for i in range(0,len(gen)): plt.figure(figsize=(128,128),dpi=1) plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300) plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300) plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300) plt.axis('off') plt.xlim(0.,1.) plt.ylim(0.,1.) plt.xticks(np.arange(0,1,0.1)) plt.yticks(np.arange(0,1,0.1)) if not os.path.exists("keras_gen"): os.makedirs("keras_gen") plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1) plt.close() ################################################################## #重整圖片到0-1 else: for i in range(len(gen)): plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue') plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red') plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green') plt.xlim(0.,1.) plt.ylim(0.,1.) plt.xticks(np.arange(0,1,0.1)) plt.xticks(np.arange(0,1,0.1)) plt.show() if __name__ == '__main__': gan = GAN() gan.train(epochs=300000, batch_size=32, sample_interval=2000) # gan.test(save=True)
以上這篇Keras自動下載的數(shù)據(jù)集/模型存放位置介紹就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實現(xiàn)優(yōu)雅編寫LaTeX的示例代碼
LaTeX?是一種廣泛用于排版學(xué)術(shù)論文、報告、書籍和演示文稿的標(biāo)記語言,本文主要為大家詳細(xì)介紹了如何使用?Python?來優(yōu)雅地編寫?LaTeX,提高效率并減少錯誤,需要的可以參考下2024-02-02python 將字符串中的數(shù)字相加求和的實現(xiàn)
這篇文章主要介紹了python 將字符串中的數(shù)字相加求和的實現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07python虛擬機pyc文件結(jié)構(gòu)的深入理解
這篇文章主要為大家介紹了python虛擬機之pyc文件結(jié)構(gòu)的深入探究理解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-03-03利用Celery實現(xiàn)Django博客PV統(tǒng)計功能詳解
給網(wǎng)站增加pv、uv統(tǒng)計,可以是件很簡單的事,也可以是件很復(fù)雜的事。下面這篇文章主要給大家介紹了利用Celery實現(xiàn)Django博客PV統(tǒng)計功能的相關(guān)資料,文中介紹的非常詳細(xì),需要的朋友可以參考借鑒,下面來一起看看吧。2017-05-05圖文講解選擇排序算法的原理及在Python中的實現(xiàn)
這篇文章主要介紹了選擇排序的原理及在Python中的實現(xiàn),選擇排序的時間復(fù)雜度為О(n²),需要的朋友可以參考下2016-05-05