tensorflow之自定義神經(jīng)網(wǎng)絡(luò)層實(shí)例
如下所示:
import tensorflow as tf tfe = tf.contrib.eager tf.enable_eager_execution()
大多數(shù)情況下,在為機(jī)器學(xué)習(xí)模型編寫(xiě)代碼時(shí),您希望在比單個(gè)操作和單個(gè)變量操作更高的抽象級(jí)別上操作。
1.關(guān)于圖層的一些有用操作
許多機(jī)器學(xué)習(xí)模型可以表達(dá)為相對(duì)簡(jiǎn)單的圖層的組合和堆疊,TensorFlow提供了一組許多常用圖層,以及您從頭開(kāi)始或作為組合創(chuàng)建自己的應(yīng)用程序特定圖層的簡(jiǎn)單方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras層在構(gòu)建自己的模型時(shí)非常有用。
#在tf.keras.layers包中,圖層是對(duì)象。要構(gòu)造一個(gè)圖層,只需構(gòu)造一個(gè)對(duì)象。大多數(shù)層將輸出維度/通道的數(shù)量作為第一個(gè)參數(shù)。 layer=tf.keras.layers.Dense(100) #輸入維度的數(shù)量通常是不必要的,因?yàn)樗梢栽诘谝淮问褂脠D層時(shí)推斷出來(lái),但如果您想手動(dòng)指定它,則可以提供它,這在某些復(fù)雜模型中很有用。 layer=tf.keras.layers.Dense(10,input_shape=(None,5)) #調(diào)用層 layer(tf.zeros([10,5])) #圖層有許多有用的方法。例如,您可以通過(guò)調(diào)用layer.variables來(lái)檢查圖層中的所有變量。在這種情況下,完全連接的層將具有權(quán)重和偏差的變量。 variable=layer.variables # variable[0] layer.kernel.numpy() layer.bias
2.自定義圖層
實(shí)現(xiàn)自己的層的最佳方法是擴(kuò)展tf.keras.Layer類并實(shí)現(xiàn):
__init__,您可以在其中執(zhí)行所有與輸入無(wú)關(guān)的初始化
build方法,您知道輸入張量的形狀,并可以進(jìn)行其余的初始化
call方法,在這里進(jìn)行正向傳播計(jì)算
請(qǐng)注意,您不必等到調(diào)用build來(lái)創(chuàng)建變量,您也可以在__init__中創(chuàng)建它們。但是,在build中創(chuàng)建它們的優(yōu)點(diǎn)是它可以根據(jù)圖層將要操作的輸入的形狀啟用后期變量創(chuàng)建。另一方面,在__init__中創(chuàng)建變量意味著需要明確指定創(chuàng)建變量所需的形狀。
class MyDenseLayer(tf.keras.layers.Layer): def __init__(self, num_outputs): super(MyDenseLayer, self).__init__() self.num_outputs = num_outputs def build(self, input_shape): self.kernel = self.add_variable("kernel", shape=[input_shape[-1].value, self.num_outputs]) def call(self, input): return tf.matmul(input, self.kernel) layer = MyDenseLayer(10) print(layer(tf.zeros([10, 5]))) print(layer.variables)
3.搭建網(wǎng)絡(luò)結(jié)構(gòu)
機(jī)器學(xué)習(xí)模型中許多有趣的圖層是通過(guò)組合現(xiàn)有層來(lái)實(shí)現(xiàn)的。例如,resnet中的每個(gè)residual塊是卷積,批量標(biāo)準(zhǔn)化等的組合。
創(chuàng)建包含其他圖層的類似圖層的東西時(shí)使用的主類是tf.keras.Model。實(shí)現(xiàn)一個(gè)是通過(guò)繼承自tf.keras.Model完成的。
class ResnetIdentityBlock(tf.keras.Model): def __init__(self, kernel_size, filters): super(ResnetIdentityBlock, self).__init__(name='') filters1, filters2, filters3 = filters self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1)) self.bn2a = tf.keras.layers.BatchNormalization() self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same') self.bn2b = tf.keras.layers.BatchNormalization() self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1)) self.bn2c = tf.keras.layers.BatchNormalization() def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) x = self.bn2a(x, training=training) x = tf.nn.relu(x) x = self.conv2b(x) x = self.bn2b(x, training=training) x = tf.nn.relu(x) x = self.conv2c(x) x = self.bn2c(x, training=training) x += input_tensor return tf.nn.relu(x) block = ResnetIdentityBlock(1, [1, 2, 3]) print(block(tf.zeros([1, 2, 3, 3]))) print([x.name for x in block.variables])
以上這篇tensorflow之自定義神經(jīng)網(wǎng)絡(luò)層實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)提取圖片中顏色并繪制成可視化圖表
今天小編來(lái)為大家分享一個(gè)有趣的可視化技巧,就是如何利用Python語(yǔ)言實(shí)現(xiàn)從圖片中提取顏色然后繪制成可視化圖表,感興趣的可以嘗試一下2022-07-07淺談Python中進(jìn)程的創(chuàng)建與結(jié)束
這篇文章主要介紹了淺談Python中進(jìn)程的創(chuàng)建與結(jié)束,但凡是硬件,都需要有操作系統(tǒng)去管理,只要有操作系統(tǒng),就有進(jìn)程的概念,就需要有創(chuàng)建進(jìn)程的方式,需要的朋友可以參考下2023-07-07解決pycharm臨時(shí)打包32位程序的問(wèn)題
這篇文章主要介紹了解決pycharm臨時(shí)打包32位程序的問(wèn)題,本文通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-04-04Python2.7簡(jiǎn)單連接與操作MySQL的方法
這篇文章主要介紹了Python2.7簡(jiǎn)單連接與操作MySQL的方法,涉及Python使用MySQLdb模塊操作MySQL連接及命令運(yùn)行的相關(guān)技巧,需要的朋友可以參考下2016-04-04Django 后臺(tái)帶有字典的列表數(shù)據(jù)與頁(yè)面js交互實(shí)例
這篇文章主要介紹了Django 后臺(tái)帶有字典的列表數(shù)據(jù)與頁(yè)面js交互實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-04-04Python實(shí)現(xiàn)動(dòng)態(tài)添加類的屬性或成員函數(shù)的解決方法
這篇文章主要介紹了Python實(shí)現(xiàn)動(dòng)態(tài)添加類的屬性或成員函數(shù)的解決方法,在類似插件開(kāi)發(fā)的時(shí)候會(huì)比較有用,需要的朋友可以參考下2014-07-07淺談Python使用Bottle來(lái)提供一個(gè)簡(jiǎn)單的web服務(wù)
這篇文章主要介紹了淺談Python使用Bottle來(lái)提供一個(gè)簡(jiǎn)單的web服務(wù),具有一定借鑒價(jià)值,需要的朋友可以參考下2017-12-12