欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

keras打印loss對權(quán)重的導(dǎo)數(shù)方式

 更新時間:2020年06月10日 09:25:49   作者:HackerTom  
這篇文章主要介紹了keras打印loss對權(quán)重的導(dǎo)數(shù)方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

Notes

懷疑模型梯度爆炸,想打印模型 loss 對各權(quán)重的導(dǎo)數(shù)看看。如果如果fit來訓(xùn)練的話,可以用keras.callbacks.TensorBoard實現(xiàn)。

但此次使用train_on_batch來訓(xùn)練的,用K.gradients和K.function實現(xiàn)。

Codes

以一份 VAE 代碼為例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一個參數(shù)
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 獲取模型權(quán)重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 對權(quán)重的導(dǎo)數(shù)
# KL 要是 Tensor,不能是上面的函數(shù) `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然報錯

# 打印梯度的函數(shù)
# K.function 的輸入和輸出必要是 list!就算只有一個
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能傳 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式訓(xùn)練 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 計算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必須秩為 2,兩層中括號
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

補充知識:keras 自定義損失 自動求導(dǎo)時出現(xiàn)None

問題記錄,keras 自定義損失 自動求導(dǎo)時出現(xiàn)None,后來想到是因為傳入的變量沒有使用,所以keras無法求出偏導(dǎo),修改后問題解決。就是不愿使用的變量×0,求導(dǎo)后還是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上這篇keras打印loss對權(quán)重的導(dǎo)數(shù)方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python調(diào)用騰訊云短信服務(wù)發(fā)送手機(jī)短信

    Python調(diào)用騰訊云短信服務(wù)發(fā)送手機(jī)短信

    這篇文章主要為大家介紹了Python調(diào)用騰訊云短信服務(wù)發(fā)送手機(jī)短信,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-05-05
  • python刪除列表中重復(fù)記錄的方法

    python刪除列表中重復(fù)記錄的方法

    這篇文章主要介紹了python刪除列表中重復(fù)記錄的方法,涉及Python操作列表的相關(guān)技巧,需要的朋友可以參考下
    2015-04-04
  • 3個適合新手練習(xí)的python小游戲

    3個適合新手練習(xí)的python小游戲

    這篇文章主要分析的是3個適合新手練習(xí)的python小游戲,初學(xué)者嘛就應(yīng)該多練手,下文分享的python小游戲歡迎大家來玩,需要的小伙伴也可以參考一下
    2022-01-01
  • Python設(shè)計模式行為型責(zé)任鏈模式

    Python設(shè)計模式行為型責(zé)任鏈模式

    這篇文章主要介紹了Python設(shè)計模式行為型責(zé)任鏈模式,責(zé)任鏈模式將能處理請求的對象連成一條鏈,并沿著這條鏈傳遞該請求,直到有一個對象處理請求為止,避免請求的發(fā)送者和接收者之間的耦合關(guān)系,下圍繞改內(nèi)容介紹具有一點的參考價值,需要的朋友可以參考下
    2022-02-02
  • Python中使用第三方庫xlutils來追加寫入Excel文件示例

    Python中使用第三方庫xlutils來追加寫入Excel文件示例

    這篇文章主要介紹了Python中使用第三方庫xlutils來追加寫入Excel文件示例,本文直接給出追加寫入示例和追加效果,需要的朋友可以參考下
    2015-04-04
  • Python基于DFA算法實現(xiàn)內(nèi)容敏感詞過濾

    Python基于DFA算法實現(xiàn)內(nèi)容敏感詞過濾

    DFA?算法是通過提前構(gòu)造出一個?樹狀查找結(jié)構(gòu),之后根據(jù)輸入在該樹狀結(jié)構(gòu)中就可以進(jìn)行非常高效的查找。本文將利用改算法實現(xiàn)敏感詞過濾,需要的可以參考一下
    2022-04-04
  • python保存圖片時如何和原圖大小一致

    python保存圖片時如何和原圖大小一致

    這篇文章主要介紹了python保存圖片時如何和原圖大小一致問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-11-11
  • Python數(shù)據(jù)分析numpy數(shù)組的3種創(chuàng)建方式

    Python數(shù)據(jù)分析numpy數(shù)組的3種創(chuàng)建方式

    這篇文章主要介紹了Python數(shù)據(jù)分析numpy數(shù)組的3種創(chuàng)建方式,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價值,需要的朋友可以參考一下
    2022-07-07
  • Python 加密的實例詳解

    Python 加密的實例詳解

    這篇文章主要介紹了 Python 加密的實例詳解的相關(guān)資料,這里提供了兩種實現(xiàn)方法,需要的朋友可以參考下
    2017-10-10
  • OpenCV機(jī)器學(xué)習(xí)MeanShift算法筆記分享

    OpenCV機(jī)器學(xué)習(xí)MeanShift算法筆記分享

    這篇文章主要介紹了OpenCV機(jī)器學(xué)習(xí)MeanShift算法筆記分享,有需要的朋友可以借鑒參考下,希望可以對各位讀者的OpenCV算法學(xué)習(xí)能夠有所幫助
    2021-09-09

最新評論