C++ TensorflowLite模型验证的过程详解
故事是这样的:
有一个手掌检测的tflite模型,需要在开发板上跑起来。手机版本已经成熟,要移植到开发板上。现在要验证tflite模型文件在板子上的运行结果和手机上一致。
前提:为了多次重复测试,在Android端使用了同一帧数据(从一个录制的mp4中固定取一张图),测试代码如下图
下面是测试过程
记录下Android版API运行推理前的图片数据文件(经过了归一化处理,所以都是-1~1之间的float数据)
这一步卡在了把float数据写到二进制文件中:C++读出来有问题(常见原因是字节序不一致——Java的DataOutputStream按大端序写入float,而开发板上的C++代码通常按本机的小端序直接读取)
换了个方案,直接以文本形式存储float字符串
private void saveFile(float[] pfImageData) { try { File file = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS).getAbsolutePath() + "/tfimg"); StringBuilder sb = new StringBuilder(); for (float val : pfImageData) { //保留4位小數(shù),這里可以改為其他值 sb.append(String.format("%.4f", val)); sb.append("\r\n"); } FileWriter out = new FileWriter(file); //文件寫(xiě)入流 out.write(sb.toString()); out.close(); } catch (Exception e) { e.printStackTrace(); Log.e("Melon", "存儲(chǔ)文件異常," + e.getMessage()); } }
拿着这个文件在板子上输入到Tflite模型中
测试代码,主要是RunInference()和read_file()
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/examples/label_image/label_image.h"

#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include <cstdarg>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/examples/label_image/bitmap_helpers.h"
#include "tensorflow/lite/examples/label_image/get_top_n.h"
#include "tensorflow/lite/examples/label_image/log.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/delegates/delegate_provider.h"

namespace tflite {
namespace label_image {

// Converts a timeval to microseconds.
double get_us(struct timeval t) {
  // Promote to double before multiplying. On 32-bit boards time_t/long is
  // 32 bits, so the integer product tv_sec * 1000000 overflows (signed
  // overflow is undefined behavior) for any present-day timestamp.
  return static_cast<double>(t.tv_sec) * 1000000.0 +
         static_cast<double>(t.tv_usec);
}

using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;

// Holds delegate-related parameters (GPU / NNAPI / Hexagon / XNNPACK) and
// creates the corresponding TfLite delegates from them.
class DelegateProviders {
 public:
  DelegateProviders() : delegate_list_util_(&params_) {
    delegate_list_util_.AddAllDelegateParams();
  }

  // Initializes delegate-related parameters by parsing command-line
  // arguments, and removes the matching arguments from (*argc, argv).
  // Returns true if all recognized arg values are parsed correctly; logs the
  // usage string otherwise.
  bool InitFromCmdlineArgs(int *argc, const char **argv) {
    std::vector<tflite::Flag> flags;
    // NOTE(review): registering the delegate providers' own flags is disabled
    // in this port, so `flags` stays empty and delegates are configured only
    // through MergeSettingsIntoParams().
    // delegate_list_util_.AppendCmdlineFlags(&flags);

    const bool parse_result = Flags::Parse(argc, argv, flags);
    if (!parse_result) {
      std::string usage = Flags::Usage(argv[0], flags);
      LOG(ERROR) << usage;
    }
    return parse_result;
  }

  // According to passed-in settings `s`, sets the corresponding parameters
  // defined by the various delegate execution providers. See
  // lite/tools/delegates/README.md for the full list of parameters.
  void MergeSettingsIntoParams(const Settings &s) {
    // Parse settings related to GPU delegate.
    // Note that GPU delegate does support OpenCL. 'gl_backend' was introduced
    // when the GPU delegate only supports OpenGL. Therefore, we consider
    // setting 'gl_backend' to true means using the GPU delegate.
    if (s.gl_backend) {
      if (!params_.HasParam("use_gpu")) {
        LOG(WARN) << "GPU deleate execution provider isn't linked or GPU "
                     "delegate isn't supported on the platform!";
      } else {
        params_.Set<bool>("use_gpu", true);
        // The parameter "gpu_inference_for_sustained_speed" isn't available
        // for iOS devices.
        if (params_.HasParam("gpu_inference_for_sustained_speed")) {
          params_.Set<bool>("gpu_inference_for_sustained_speed", true);
        }
        params_.Set<bool>("gpu_precision_loss_allowed", s.allow_fp16);
      }
    }

    // Parse settings related to NNAPI delegate.
    if (s.accel) {
      if (!params_.HasParam("use_nnapi")) {
        LOG(WARN) << "NNAPI deleate execution provider isn't linked or NNAPI "
                     "delegate isn't supported on the platform!";
      } else {
        params_.Set<bool>("use_nnapi", true);
        params_.Set<bool>("nnapi_allow_fp16", s.allow_fp16);
      }
    }

    // Parse settings related to Hexagon delegate.
    if (s.hexagon_delegate) {
      if (!params_.HasParam("use_hexagon")) {
        LOG(WARN) << "Hexagon deleate execution provider isn't linked or "
                     "Hexagon delegate isn't supported on the platform!";
      } else {
        params_.Set<bool>("use_hexagon", true);
        params_.Set<bool>("hexagon_profiling", s.profiling);
      }
    }

    // Parse settings related to XNNPACK delegate.
    if (s.xnnpack_delegate) {
      if (!params_.HasParam("use_xnnpack")) {
        LOG(WARN) << "XNNPACK deleate execution provider isn't linked or "
                     "XNNPACK delegate isn't supported on the platform!";
      } else {
        params_.Set<bool>("use_xnnpack", true);
        // "num_threads" is a thread count. Setting it as bool (as before)
        // collapsed any count to 0/1 and used a type that does not match the
        // int32 parameter, which ToolParams::Set<T> checks.
        if (params_.HasParam("num_threads")) {
          params_.Set<int32_t>("num_threads", s.number_of_threads);
        }
      }
    }
  }

  // Creates a list of TfLite delegates based on what has been initialized
  // in `params_`.
  std::vector<ProvidedDelegateList::ProvidedDelegate> CreateAllDelegates()
      const {
    return delegate_list_util_.CreateAllRankedDelegates();
  }

 private:
  // Delegate-related parameters. Must be declared before
  // `delegate_list_util_`, whose constructor receives a pointer to it.
  tflite::tools::ToolParams params_;

  // A helper to create TfLite delegates.
  ProvidedDelegateList delegate_list_util_;
};
// std::vector<uint8_t> read_file(const std::string &input_bmp_name) // { // int begin, end; // std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); // if (!file) // { // LOG(FATAL) << "input file " << input_bmp_name << " not found"; // exit(-1); // } // begin = file.tellg(); // file.seekg(0, std::ios::end); // end = file.tellg(); // size_t len = end - begin; // LOG(INFO) << "len: " << len; // std::vector<uint8_t> img_bytes(len); // file.seekg(0, std::ios::beg); // file.read(reinterpret_cast<char *>(img_bytes.data()), len); // return img_bytes; // } /** * 讀取文件 */ std::vector<float> read_file(const std::string &input_bmp_name) { int begin, end; std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); if (!file) { LOG(FATAL) << "input file " << input_bmp_name << " not found"; exit(-1); } begin = file.tellg(); file.seekg(0, std::ios::end); end = file.tellg(); size_t len = end - begin; LOG(INFO) << "len: " << len; std::vector<float> img_bytes; file.seekg(0, std::ios::beg); string strLine = ""; float temp; while (getline(file, strLine)) { temp = atof(strLine.c_str()); img_bytes.push_back(temp); } LOG(INFO) << "文件讀取完成:" << input_bmp_name; return img_bytes; } /** * 運(yùn)行推理 */ void RunInference(Settings *settings) { if (!settings->model_name.c_str()) { LOG(ERROR) << "no model file name"; exit(-1); } std::unique_ptr<tflite::FlatBufferModel> model; std::unique_ptr<tflite::Interpreter> interpreter; model = tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str()); if (!model) { LOG(ERROR) << "Failed to mmap model " << settings->model_name; exit(-1); } settings->model = model.get(); LOG(INFO) << "Loaded model " << settings->model_name; model->error_reporter(); LOG(INFO) << "resolved reporter"; tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter); //生成interpreter if (!interpreter) { LOG(ERROR) << "Failed to construct interpreter"; exit(-1); } 
interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16); if (settings->verbose) { LOG(INFO) << "tensors size: " << interpreter->tensors_size(); LOG(INFO) << "nodes size: " << interpreter->nodes_size(); LOG(INFO) << "inputs: " << interpreter->inputs().size(); LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0); int t_size = interpreter->tensors_size(); for (int i = 0; i < t_size; i++) { if (interpreter->tensor(i)->name) LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " << interpreter->tensor(i)->bytes << ", " << interpreter->tensor(i)->type << ", " << interpreter->tensor(i)->params.scale << ", " << interpreter->tensor(i)->params.zero_point; } } if (settings->number_of_threads != -1) { interpreter->SetNumThreads(settings->number_of_threads); } int image_width = 128; int image_height = 128; int image_channels = 3; // std::vector<uint8_t> in = read_bmp(settings->input_bmp_name, &image_width, &image_height, &image_channels, settings); std::vector<float> file_bytes = read_file(settings->input_bmp_name); for (int i = 0; i < 100; i++) { //和Android的輸入做對(duì)比 LOG(INFO) << i << ": " << file_bytes[i]; } /* inputs()[0]得到輸入張量數(shù)組中的第一個(gè)張量,也就是classifier中唯一的那個(gè)輸入張量; input是個(gè)整型值,是張量列表中的引索 */ int input = interpreter->inputs()[0]; LOG(INFO) << "input: " << input; const std::vector<int> inputs = interpreter->inputs(); const std::vector<int> outputs = interpreter->outputs(); LOG(INFO) << "number of inputs: " << inputs.size(); LOG(INFO) << "input index: " << inputs[0]; LOG(INFO) << "number of outputs: " << outputs.size(); LOG(INFO) << "outputs index1: " << outputs[0] << ",outputs index2: " << outputs[1]; if (interpreter->AllocateTensors() != kTfLiteOk) { //加載所有tensor LOG(ERROR) << "Failed to allocate tensors!"; exit(-1); } if (settings->verbose) PrintInterpreterState(interpreter.get()); // 從輸入張量的原數(shù)據(jù)中得到輸入尺寸 TfLiteIntArray *dims = interpreter->tensor(input)->dims; int wanted_height = dims->data[1]; int wanted_width = dims->data[2]; int 
wanted_channels = dims->data[3]; settings->input_type = interpreter->tensor(input)->type; //typed_tensor返回一個(gè)經(jīng)過(guò)固定數(shù)據(jù)類型轉(zhuǎn)換的tensor指針 //以input為索引,在TfLiteTensor* content_.tensors這個(gè)張量表得到具體的張量 //返回該張量的data.raw,它指示張量正關(guān)聯(lián)著的內(nèi)存塊 // resize<float>(interpreter->typed_tensor<float>(input), in.data(), // image_height, image_width, image_channels, wanted_height, // wanted_width, wanted_channels, settings); //賦值給input tensor float *inputP = interpreter->typed_input_tensor<float>(0); LOG(INFO) << "file_bytes size: " << file_bytes.size(); for (int i = 0; i < file_bytes.size(); i++) { inputP[i] = file_bytes[i]; } struct timeval start_time, stop_time; gettimeofday(&start_time, nullptr); for (int i = 0; i < settings->loop_count; i++) { //調(diào)用模型進(jìn)行推理 if (interpreter->Invoke() != kTfLiteOk) { LOG(ERROR) << "Failed to invoke tflite!"; exit(-1); } } gettimeofday(&stop_time, nullptr); LOG(INFO) << "invoked"; LOG(INFO) << "average time: " << (get_us(stop_time) - get_us(start_time)) / (settings->loop_count * 1000) << " ms"; const float threshold = 0.001f; int output = interpreter->outputs()[1]; LOG(INFO) << "output: " << output; LOG(INFO) << "interpreter->tensors_size: " << interpreter->tensors_size(); TfLiteTensor *tensor = interpreter->tensor(output); TfLiteIntArray *output_dims = tensor->dims; // assume output dims to be something like (1, 1, ... 
,size) auto output_size = output_dims->data[output_dims->size - 1]; LOG(INFO) << "索引為" << output << "的輸出張量的-" << "output_size: " << output_size; for (int i = 0; i < output_dims->size; i++) { LOG(INFO) << "元數(shù)據(jù)有:" << output_dims->data[i]; } float *prediction = interpreter->typed_output_tensor<float>(1); float classificators[1][896][1]; memcpy(classificators, prediction, 896 * 1 * sizeof(float)); // float classificators[1][896][18]; // memcpy(classificators, prediction, 896 * 18 * sizeof(float)); //輸出分類結(jié)果 for (float(&r)[896][1] : classificators) { for (float(&p)[1] : r) { for (float &q : p) { std::cout << q << ' '; } std::cout << std::endl; } std::cout << std::endl; } } int Main(int argc, char **argv) { DelegateProviders delegate_providers; bool parse_result = delegate_providers.InitFromCmdlineArgs( &argc, const_cast<const char **>(argv)); if (!parse_result) { return EXIT_FAILURE; } Settings s; int c; while (true) { static struct option long_options[] = { {"accelerated", required_argument, nullptr, 'a'}, {"allow_fp16", required_argument, nullptr, 'f'}, {"count", required_argument, nullptr, 'c'}, {"verbose", required_argument, nullptr, 'v'}, {"image", required_argument, nullptr, 'i'}, {"labels", required_argument, nullptr, 'l'}, {"tflite_model", required_argument, nullptr, 'm'}, {"profiling", required_argument, nullptr, 'p'}, {"threads", required_argument, nullptr, 't'}, {"input_mean", required_argument, nullptr, 'b'}, {"input_std", required_argument, nullptr, 's'}, {"num_results", required_argument, nullptr, 'r'}, {"max_profiling_buffer_entries", required_argument, nullptr, 'e'}, {"warmup_runs", required_argument, nullptr, 'w'}, {"gl_backend", required_argument, nullptr, 'g'}, {"hexagon_delegate", required_argument, nullptr, 'j'}, {"xnnpack_delegate", required_argument, nullptr, 'x'}, {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. 
*/ int option_index = 0; c = getopt_long(argc, argv, "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:x:", long_options, &option_index); /* Detect the end of the options. */ if (c == -1) break; switch (c) { case 'a': s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'b': s.input_mean = strtod(optarg, nullptr); break; case 'c': s.loop_count = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'e': s.max_profiling_buffer_entries = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'f': s.allow_fp16 = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'g': s.gl_backend = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'i': s.input_bmp_name = optarg; break; case 'j': s.hexagon_delegate = optarg; break; case 'l': s.labels_file_name = optarg; break; case 'm': s.model_name = optarg; break; case 'p': s.profiling = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'r': s.number_of_results = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 's': s.input_std = strtod(optarg, nullptr); break; case 't': s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) optarg, nullptr, 10); break; case 'v': s.verbose = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'w': s.number_of_warmup_runs = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'x': s.xnnpack_delegate = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'h': case '?': /* getopt_long already printed an error message. */ exit(-1); default: exit(-1); } } delegate_providers.MergeSettingsIntoParams(s); RunInference(&s); return 0; } } // namespace label_image } // namespace tflite int main(int argc, char **argv) { return tflite::label_image::Main(argc, argv); }
运行指令: ./ws_app --tflite_model libnewpalm_detection.tflite --image tfimg
对比推理前的输入,两端一致
Android端
开发板上
对比推理后的输出,两端一致
Android端
开发板端
到此这篇关于C++ TensorflowLite模型验证的文章就介绍到这了,更多相关C++ TensorflowLite模型验证内容请搜索脚本之家以前的文章或继续浏览下面的相关文章,希望大家以后多多支持脚本之家!
相关文章
C语言实现按行读写文件
这篇文章主要为大家详细介绍了C语言实现按行读写文件,文中示例代码介绍得非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下 2019-11-11
C++11中的可变参数模板/lambda表达式
C++11的新特性可变参数模板能够让我们创建可以接受可变参数的函数模板和类模板。相比C++98和C++03中类模板和函数模板只能含固定数量的模板参数,可变参数模板无疑是一个巨大的改进。这篇文章主要介绍了C++11中的可变参数模板/lambda表达式,需要的朋友可以参考下 2023-03-03