Keras-多輸入多輸出實例(多任務)
1、模型結(jié)果設計
2、代碼
from keras import Input, Model from keras.layers import Dense, Concatenate import numpy as np from keras.utils import plot_model from numpy import random as rd samples_n = 3000 samples_dim_01 = 2 samples_dim_02 = 2 # 樣本數(shù)據(jù) x1 = rd.rand(samples_n, samples_dim_01) x2 = rd.rand(samples_n, samples_dim_02) y_1 = [] y_2 = [] y_3 = [] for x11, x22 in zip(x1, x2): y_1.append(np.sum(x11) + np.sum(x22)) y_2.append(np.max([np.max(x11), np.max(x22)])) y_3.append(np.min([np.min(x11), np.min(x22)])) y_1 = np.array(y_1) y_1 = np.expand_dims(y_1, axis=1) y_2 = np.array(y_2) y_2 = np.expand_dims(y_2, axis=1) y_3 = np.array(y_3) y_3 = np.expand_dims(y_3, axis=1) # 輸入層 inputs_01 = Input((samples_dim_01,), name='input_1') inputs_02 = Input((samples_dim_02,), name='input_2') # 全連接層 dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01) dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01) dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02) # 加入合并層 merge = Concatenate()([dense_011, dense_02]) # 分成兩類輸出 --- 輸出01 output_01 = Dense(units=6, activation="relu", name='output01')(merge) output_011 = Dense(units=1, activation=None, name='output011')(output_01) # 分成兩類輸出 --- 輸出02 output_02 = Dense(units=1, activation=None, name='output02')(merge) # 分成兩類輸出 --- 輸出03 output_03 = Dense(units=1, activation=None, name='output03')(merge) # 構(gòu)造一個新模型 model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011, output_02, output_03 ]) # 顯示模型情況 plot_model(model, show_shapes=True) print(model.summary()) # # 編譯 # model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1, # 0.8, # 0.8 # ]) # # 訓練 # model.fit([x1, x2], [y_1, # y_2, # y_3 # ], epochs=50, batch_size=32, validation_split=0.1) # 以下的方法可靈活設置 model.compile(optimizer='adam', loss={'output011': 'mean_squared_error', 'output02': 'mean_squared_error', 'output03': 'mean_squared_error'}, loss_weights={'output011': 1, 'output02': 0.8, 'output03': 0.8}) model.fit({'input_1': x1, 'input_2': x2}, {'output011': y_1, 'output02': y_2, 'output03': y_3}, epochs=50, batch_size=32, validation_split=0.1) # 預測 test_x1 = rd.rand(1, 2) test_x2 = rd.rand(1, 2) test_y = model.predict(x=[test_x1, test_x2]) # 測試 print("測試結(jié)果:") print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))
補充知識:Keras多輸出(多任務)如何設置fit_generator
在使用Keras的時候,因為需要考慮到效率問題,需要修改fit_generator來適應多輸出
# create model model = Model(inputs=x_inp, outputs=[main_pred, aux_pred]) # complie model model.compile( optimizer=optimizers.Adam(lr=learning_rate), loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)}, loss_weights={"main": 0.5, "auxiliary": 0.5}, metrics=[metrics.binary_accuracy], ) # Train model model.fit_generator( train_gen, epochs=num_epochs, verbose=0, shuffle=True )
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either
a tuple (inputs, targets)
a tuple (inputs, targets, sample_weights).
Keras設計多輸出(多任務)使用fit_generator的步驟如下:
根據(jù)官方文檔,定義一個generator或者一個class繼承Sequence
class Batch_generator(Sequence): """ 用于產(chǎn)生batch_1, batch_2(記住是numpy.array格式轉(zhuǎn)換) """ y_batch = {'main':batch_1,'auxiliary':batch_2} return X_batch, y_batch # or in another way def batch_generator(): """ 用于產(chǎn)生batch_1, batch_2(記住是numpy.array格式轉(zhuǎn)換) """ yield X_batch, {'main': batch_1,'auxiliary':batch_2}
重要的事情說三遍(親自采坑,搜了一大圈才發(fā)現(xiàn)滴):
如果是多輸出(多任務)的時候,這里的target是字典類型
如果是多輸出(多任務)的時候,這里的target是字典類型
如果是多輸出(多任務)的時候,這里的target是字典類型
以上這篇Keras-多輸入多輸出實例(多任務)就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實現(xiàn)TCP協(xié)議下的端口映射功能的腳本程序示例
端口映射一個最基本的運作形態(tài)就是通過一個中間端口將一個端口發(fā)送的數(shù)據(jù)全部轉(zhuǎn)給另一個端口,well,這里我們就來看一下Python實現(xiàn)TCP協(xié)議下的端口映射功能的腳本程序示例2016-06-06如何用python獲取EXCEL文件內(nèi)容并保存到DBC
很多時候,使用python進行數(shù)據(jù)分析的第一步就是讀取excel文件,下面這篇文章主要給大家介紹了關(guān)于如何用python獲取EXCEL文件內(nèi)容并保存到DBC的相關(guān)資料,需要的朋友可以參考2023-12-12PyTorch讀取Cifar數(shù)據(jù)集并顯示圖片的實例講解
今天小編就為大家分享一篇PyTorch讀取Cifar數(shù)據(jù)集并顯示圖片的實例講解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07Python的Flask項目中獲取請求用戶IP地址 addr問題
這篇文章主要介紹了Python的Flask項目中獲取請求用戶IP地址 addr問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-01-01python實現(xiàn)數(shù)據(jù)清洗(缺失值與異常值處理)
今天小編就為大家分享一篇python實現(xiàn)數(shù)據(jù)清洗(缺失值與異常值處理),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12PyTorch零基礎(chǔ)入門之構(gòu)建模型基礎(chǔ)
PyTorch是一個開源的Python機器學習庫,基于Torch,用于自然語言處理等應用程序,它是一個可續(xù)計算包,提供兩個高級功能:1、具有強大的GPU加速的張量計算(如NumPy)。2、包含自動求導系統(tǒng)的深度神經(jīng)網(wǎng)絡2021-10-10