Pytorch轉(zhuǎn)tflite方式
目標是想把在服務器上用pytorch訓練好的模型轉(zhuǎn)換為可以在移動端運行的tflite模型。
最直接的思路是想把pytorch模型轉(zhuǎn)換為tensorflow的模型,然后轉(zhuǎn)換為tflite。但是這個轉(zhuǎn)換目前沒有發(fā)現(xiàn)比較靠譜的方法。
經(jīng)過調(diào)研發(fā)現(xiàn)最新的tflite已經(jīng)支持直接從keras模型的轉(zhuǎn)換,所以可以采用keras作為中間轉(zhuǎn)換的橋梁,這樣就能充分利用keras高層API的便利性。
轉(zhuǎn)換的基本思想就是用pytorch中的各層網(wǎng)絡的權重取出來后直接賦值給keras網(wǎng)絡中的對應layer層的權重。
轉(zhuǎn)換為Keras模型后,再通過tf.contrib.lite.TocoConverter把模型直接轉(zhuǎn)為tflite.
下面是一個例子,假設轉(zhuǎn)換的是一個兩層的CNN網(wǎng)絡。
import tensorflow as tf
from tensorflow import keras
import numpy as np
import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable
class PytorchNet(nn.Module):
def __init__(self):
super(PytorchNet, self).__init__()
conv1 = nn.Sequential(
nn.Conv2d(3, 32, 3, 2),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2))
conv2 = nn.Sequential(
nn.Conv2d(32, 64, 3, 1, groups=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2))
self.feature = nn.Sequential(conv1, conv2)
self.init_weights()
def forward(self, x):
return self.feature(x)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight.data, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def KerasNet(input_shape=(224, 224, 3)):
image_input = keras.layers.Input(shape=input_shape)
# conv1
network = keras.layers.Conv2D(
32, (3, 3), strides=(2, 2), padding="valid")(image_input)
network = keras.layers.BatchNormalization(
trainable=False, fused=False)(network)
network = keras.layers.Activation("relu")(network)
network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)
# conv2
network = keras.layers.Conv2D(
64, (3, 3), strides=(1, 1), padding="valid")(network)
network = keras.layers.BatchNormalization(
trainable=False, fused=True)(network)
network = keras.layers.Activation("relu")(network)
network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)
model = keras.Model(inputs=image_input, outputs=network)
return model
class PytorchToKeras(object):
def __init__(self, pModel, kModel):
super(PytorchToKeras, self)
self.__source_layers = []
self.__target_layers = []
self.pModel = pModel
self.kModel = kModel
tf.keras.backend.set_learning_phase(0)
def __retrieve_k_layers(self):
for i, layer in enumerate(self.kModel.layers):
if len(layer.weights) > 0:
self.__target_layers.append(i)
def __retrieve_p_layers(self, input_size):
input = torch.randn(input_size)
input = Variable(input.unsqueeze(0))
hooks = []
def add_hooks(module):
def hook(module, input, output):
if hasattr(module, "weight"):
# print(module)
self.__source_layers.append(module)
if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
hooks.append(module.register_forward_hook(hook))
self.pModel.apply(add_hooks)
self.pModel(input)
for hook in hooks:
hook.remove()
def convert(self, input_size):
self.__retrieve_k_layers()
self.__retrieve_p_layers(input_size)
for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
print(source_layer)
weight_size = len(source_layer.weight.data.size())
transpose_dims = []
for i in range(weight_size):
transpose_dims.append(weight_size - i - 1)
if isinstance(source_layer, nn.Conv2d):
transpose_dims = [2,3,1,0]
self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
).transpose(transpose_dims), source_layer.bias.data.numpy()])
elif isinstance(source_layer, nn.BatchNorm2d):
self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])
def save_model(self, output_file):
self.kModel.save(output_file)
def save_weights(self, output_file):
self.kModel.save_weights(output_file, save_format='h5')
pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))
torch.save(pytorch_model, 'test.pth')
#Load the pretrained model
pytorch_model = torch.load('test.pth')
# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))
# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")
# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
"keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)
補充知識:tensorflow模型轉(zhuǎn)換成tensorflow lite模型
1.把graph和網(wǎng)絡模型打包在一個文件中
bazel build tensorflow/python/tools:freeze_graph && \ bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=eval_graph_def.pb \ --input_checkpoint=checkpoint \ --output_graph=frozen_eval_graph.pb \ --output_node_names=outputs
For example:
bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \ --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \ --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \ --output_node_names=MobilenetV1/Predictions/Reshape_1
2.把第一步中生成的tensorflow pb模型轉(zhuǎn)換為tf lite模型
轉(zhuǎn)換前需要先編譯轉(zhuǎn)換工具
bazel build tensorflow/contrib/lite/toco:toco
轉(zhuǎn)換分兩種,一種的轉(zhuǎn)換為float的tf lite,另一種可以轉(zhuǎn)換為對模型進行unit8的量化版本的模型。兩種方式如下:
非量化的轉(zhuǎn)換:
./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官網(wǎng)給的這個路徑不對 ./bazel-bin/tensorflow/contrib/lite/toco/toco \ —input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \ —output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ --inference_type=FLOAT \ --input_shape="1,224, 224,3" \ --input_array=input \ --output_array=MobilenetV1/Predictions/Reshape_1
量化方式的轉(zhuǎn)換(注意,只有量化訓練的模型才能進行量化的tf_lite轉(zhuǎn)換):
./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ ./bazel-bin/tensorflow/contrib/lite/toco/toco \ --input_file=frozen_eval_graph.pb \ --output_file=tflite_model.tflite \ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ --inference_type=QUANTIZED_UINT8 \ --input_shape="1,224, 224,3" \ --input_array=input \ --output_array=outputs \ --std_value=127.5 --mean_value=127.5
以上這篇Pytorch轉(zhuǎn)tflite方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
socket + select 完成偽并發(fā)操作的實例
下面小編就為大家?guī)硪黄猻ocket + select 完成偽并發(fā)操作的實例。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-08-08
Python執(zhí)行外部命令subprocess的使用詳解
subeprocess模塊是python自帶的模塊,無需安裝,主要用來取代一些就的模塊或方法,本文通過實例代碼給大家分享Python執(zhí)行外部命令subprocess及使用方法,感興趣的朋友跟隨小編一起看看吧2021-05-05
實現(xiàn)Windows下設置定時任務來運行python腳本
這篇文章主要介紹了實現(xiàn)Windows下設置定時任務來運行python腳本的完整過程,有需要的朋友可以借鑒參考下,希望對廣大讀者朋友能夠有所幫助2021-09-09

