解決tensorflow 與keras 混用之坑
在使用tensorflow與keras混用是model.save 是正常的但是在load_model的時(shí)候報(bào)錯(cuò)了在這里mark 一下
其中錯(cuò)誤為:TypeError: tuple indices must be integers, not list
再一一番百度后無(wú)結(jié)果,上谷歌后找到了類(lèi)似的問(wèn)題。但是是一對(duì)鳥(niǎo)文不知道什么東西(翻譯后發(fā)現(xiàn)是俄文)。后來(lái)谷歌翻譯了一下找到了解決方法。故將原始問(wèn)題文章貼上來(lái)警示一下
原訓(xùn)練代碼
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator from tensorflow.python.keras.models import Sequential from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense #Каталог с данными для обучения train_dir = 'train' # Каталог с данными для проверки val_dir = 'val' # Каталог с данными для тестирования test_dir = 'val' # Размеры изображения img_width, img_height = 800, 800 # Размерность тензора на основе изображения для входных данных в нейронную сеть # backend Tensorflow, channels_last input_shape = (img_width, img_height, 3) # Количество эпох epochs = 1 # Размер мини-выборки batch_size = 4 # Количество изображений для обучения nb_train_samples = 300 # Количество изображений для проверки nb_validation_samples = 25 # Количество изображений для тестирования nb_test_samples = 25 model = Sequential() model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape)) model.add(BatchNormalization()) model.add(Activation('tanh')) model.add(MaxPooling2D(pool_size=(10, 10))) model.add(Conv2D(64, (5, 5), padding="same")) model.add(BatchNormalization()) model.add(Activation('tanh')) model.add(MaxPooling2D(pool_size=(10, 10))) model.add(Flatten()) model.add(Dense(512)) model.add(Activation('relu')) model.add(Dropout(0.5)) model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer="Nadam", metrics=['accuracy']) print(model.summary()) datagen = ImageDataGenerator(rescale=1. / 255) train_generator = datagen.flow_from_directory( train_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') val_generator = datagen.flow_from_directory( val_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') test_generator = datagen.flow_from_directory( test_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') model.fit_generator( train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=epochs, validation_data=val_generator, validation_steps=nb_validation_samples // batch_size) print('Сохраняем сеть') model.save("grib.h5") print("Сохранение завершено!")
模型載入
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator from tensorflow.python.keras.models import Sequential from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense from keras.models import load_model print("Загрузка сети") model = load_model("grib.h5") print("Загрузка завершена!")
報(bào)錯(cuò)
/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
File "/home/disk2/py/neroset/do.py", line 13, in <module>
model = load_model("grib.h5")
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
model = model_from_config(model_config, custom_objects=custom_objects)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
return layer_module.deserialize(config, custom_objects=custom_objects)
File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
list(custom_objects.items())))
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
model.add(layer)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
output_tensor = layer(self.outputs[0])
File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
self.build(input_shapes[0])
File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
Process finished with exit code 1
戰(zhàn)斗種族解釋
убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(譯文:整理BatchNormalization一切正常。 不要告訴我錯(cuò)誤是什么?我發(fā)現(xiàn)保存keras和規(guī)范化tensorflow不能一起工作;只需更改導(dǎo)入字符串即可。)
強(qiáng)調(diào)文本 強(qiáng)調(diào)文本
keras.preprocessing.image import ImageDataGenerator keras.models import Sequential keras.layers import Conv2D, MaxPooling2D, BatchNormalization keras.layers import Activation, Dropout, Flatten, Dense
##完美解決
##附上原文鏈接
https://qa-help.ru/questions/keras-batchnormalization
補(bǔ)充:keras和tensorflow模型同時(shí)讀取要慎重
項(xiàng)目中,先讀取了一個(gè)keras模型獲取模型輸入size,再加載keras轉(zhuǎn)tensorflow后的pb模型進(jìn)行預(yù)測(cè)。
報(bào)錯(cuò):
Attempting to use uninitialized value batch_normalization_14/moving_mean
逛論壇,有建議加上初始化:
sess.run(tf.global_variables_initializer())
但是這樣的話,會(huì)導(dǎo)致模型參數(shù)全部變成初始化數(shù)據(jù)。無(wú)法使用預(yù)測(cè)模型參數(shù)。
最后發(fā)現(xiàn),將keras模型的加載去掉即可。
猜測(cè)原因:keras模型和tensorflow模型同時(shí)讀取有坑
import cv2 import numpy as np from keras.models import load_model from utils.datasets import get_labels from utils.preprocessor import preprocess_input import time import os import tensorflow as tf from tensorflow.python.platform import gfile os.environ["CUDA_VISIBLE_DEVICES"] = "-1" emotion_labels = get_labels('fer2013') emotion_target_size = (64,64) #emotion_model_path = './models/emotion_model.hdf5' #emotion_classifier = load_model(emotion_model_path) #emotion_target_size = emotion_classifier.input_shape[1:3] path = '/mnt/nas/cv_data/emotion/test' filelist = os.listdir(path) total_num = len(filelist) timeall = 0 n = 0 sess = tf.Session() #sess.run(tf.global_variables_initializer()) with gfile.FastGFile("./trans_model/emotion_mode.pb", 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') pred = sess.graph.get_tensor_by_name("predictions/Softmax:0") ######################img########################## for item in filelist: if (item == '.DS_Store') | (item == 'Thumbs.db'): continue src = os.path.join(os.path.abspath(path), item) bgr_image = cv2.imread(src) gray_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY) gray_face = gray_image try: gray_face = cv2.resize(gray_face, (emotion_target_size)) except: continue gray_face = preprocess_input(gray_face, True) gray_face = np.expand_dims(gray_face, 0) gray_face = np.expand_dims(gray_face, -1) input = sess.graph.get_tensor_by_name('input_1:0') res = sess.run(pred, {input: gray_face}) print("src:", src) emotion_probability = np.max(res[0]) emotion_label_arg = np.argmax(res[0]) emotion_text = emotion_labels[emotion_label_arg] print("predict:", res[0], ",prob:", emotion_probability, ",label:", emotion_label_arg, ",text:",emotion_text)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python如何實(shí)現(xiàn)動(dòng)態(tài)數(shù)組
這篇文章主要介紹了Python如何實(shí)現(xiàn)動(dòng)態(tài)數(shù)組,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-11-11pandas實(shí)現(xiàn)滑動(dòng)窗口的示例代碼
本文主要介紹了pandas實(shí)現(xiàn)滑動(dòng)窗口的示例代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03python安裝pandas庫(kù)不成功原因分析及解決辦法
Pandas是python中非常常用的數(shù)據(jù)分析庫(kù),在數(shù)據(jù)分析、機(jī)器學(xué)習(xí)、深度學(xué)習(xí)等領(lǐng)域經(jīng)常被使用,下面這篇文章主要給大家介紹了關(guān)于python安裝pandas庫(kù)不成功原因分析及解決辦法的相關(guān)資料2023-11-11Numpy中轉(zhuǎn)置transpose、T和swapaxes的實(shí)例講解
下面小編就為大家分享一篇Numpy中轉(zhuǎn)置transpose、T和swapaxes的實(shí)例講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04Python實(shí)現(xiàn)Linux監(jiān)控的方法
本文通過(guò)實(shí)例代碼給大家介紹了Python實(shí)現(xiàn)Linux監(jiān)控的方法,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-05-05Django2 連接MySQL及model測(cè)試實(shí)例分析
這篇文章主要介紹了Django2 連接MySQL及model測(cè)試,結(jié)合實(shí)例形式分析了Django2框架使用pymysql庫(kù)進(jìn)行mysql數(shù)據(jù)庫(kù)連接與model調(diào)用測(cè)試方法,需要的朋友可以參考下2019-12-12Python判斷值是否在list或set中的性能對(duì)比分析
這篇文章主要介紹了Python判斷值是否在list或set中的性能對(duì)比分析,結(jié)合實(shí)例形式對(duì)比分析了使用list與set循環(huán)的執(zhí)行效率,需要的朋友可以參考下2016-04-04