How to Train Your Own Dataset with TensorFlow.NET in C#
In this article I walk through, with code, how to use SciSharp STACK's TensorFlow.NET to train a CNN model for image classification. The code runs as-is on either CPU or GPU and can be trained and run on your own local image dataset. TensorFlow.NET is a complete implementation of TensorFlow on the .NET Standard framework, supporting both .NET Framework and .NET Core, and it gives .NET developers a full machine-learning framework of their own.
SciSharp STACK: https://github.com/SciSharp
What is TensorFlow.NET?
TensorFlow.NET is a contribution of the SciSharp STACK open-source community. Its mission is to build a machine-learning platform that belongs entirely to .NET developers; for C# developers in particular it aims for essentially zero learning curve. The platform wraps a large API surface and low-level bindings so that TensorFlow's Python coding style and habits carry over to .NET almost unchanged. The figure below compares the Python and C# implementations of the same TF task; the syntactic similarity is easy to see.

Because TensorFlow.NET performs well on .NET, and pairs with SciSharp's NumSharp, SharpCV, Pandas.NET, Keras.NET, Matplotlib.Net and other modules, it can be used entirely without a Python environment. It has already been integrated as a low-level algorithm provider in Microsoft's official ML.NET, and Google references it in the TensorFlow documentation as a recommendation to .NET developers worldwide.
SciSharp product structure

Microsoft ML.NET low-level algorithm integration

Officially recommended by Google for .NET developers
URL: https://www.tensorflow.org/versions/r2.0/api_docs

Project overview
This article uses TensorFlow.NET to build a simple image-classification model for single-character OCR of printed characters on an industrial line: a large raw image is captured from an industrial camera, OpenCV preprocesses it and segments the characters, and each small single-character image is sent to TF for inference. The per-character results are then concatenated in order into the complete string and returned to the main program for the downstream production steps.
In practice, to train on your own images you only need to swap your pictures into the training folders, keeping the prescribed layout (sketched below). Both GPU and CPU are supported, and the complete code for this project is available on GitHub.
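The expected layout, reconstructed from the loading code later in this article (the root folder here is named after the sample zip; in the code it is the Name directory, and the class subfolder names are free to vary, since they become the labels):

data_CnnInYourOwnData
├── train
│   ├── X
│   ├── Y
│   └── Z
├── validation
│   ├── X
│   ├── Y
│   └── Z
└── test
    ├── X
    ├── Y
    └── Z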
Model overview
The CNN in this project consists of two convolution + pooling stages and one fully connected layer, with the common ReLU activation, making it a fairly shallow convolutional network. One of the hyperparameters, the learning rate, uses a custom stepwise-decay schedule, described in detail later. The shape of each layer is sketched below:
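The original shape diagram does not reproduce here, so the following trace is a rough reconstruction; the 5×5 filters with 16 and 32 feature maps are assumptions, since the article never shows the values behind filter_size1/num_filters1 and friends:

Input   : 64 × 64 × 1    grayscale image
conv1   : 64 × 64 × 16   5×5 conv, SAME padding, ReLU
pool1   : 32 × 32 × 16   2×2 max-pool, stride 2
conv2   : 32 × 32 × 32   5×5 conv, SAME padding, ReLU
pool2   : 16 × 16 × 32   2×2 max-pool, stride 2
flatten : 8192
FC1     : h1 hidden units, ReLU
OUT     : 3 logits, one per class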

Dataset description
To keep training fast for this demonstration, the image dataset is a small excerpt of the OCR character set, covering just X, Y and Z. Its characteristics:
Number of classes: 3 (X / Y / Z)
Image size: width 64 × height 64
Image channels: 1 (grayscale)
Image counts (1,152 train, 288 validation and 288 test images in total):
- train: X - 384 ; Y - 384 ; Z - 384
- validation: X - 96 ; Y - 96 ; Z - 96
- test: X - 96 ; Y - 96 ; Z - 96
Notes: the dataset has already been augmented with random flips, translations, scaling and mirroring.
The overall dataset is illustrated below:



Code walkthrough
Environment setup
- .NET framework: .NET Framework 4.7.2 or later, or .NET Core 2.2 or later
- CPU build: either Any CPU or x64 works
- GPU build: you need to install CUDA and set the environment variables yourself; CUDA v10.1 with cuDNN v7.5 is recommended
Libraries and namespaces
Install the required dependencies from NuGet, mainly the SciSharp libraries, listed below:
Note: install the latest versions where possible, and use SciSharp's SharpCV for the CV side so that variables pass between CV and TF without conversion.
<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.15.0" />
<PackageReference Include="SciSharp.TensorFlowHub" Version="0.0.5" />
<PackageReference Include="SharpCV" Version="0.2.0" />
<PackageReference Include="SharpZipLib" Version="1.2.0" />
<PackageReference Include="System.Drawing.Common" Version="4.7.0" />
<PackageReference Include="TensorFlow.NET" Version="0.14.0" />
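For GPU training, the CPU redistributable above is typically swapped for the GPU build; the package name below is an assumption based on the SciSharp NuGet feed, so verify it there before relying on it:

<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="1.15.0" />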
Reference the namespaces, including NumSharp, Tensorflow and SharpCV:
using NumSharp;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using SharpCV;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using Tensorflow;
using static Tensorflow.Binding;
using static SharpCV.Binding;
using System.Collections.Concurrent;
using System.Threading.Tasks;
Main program structure
The main flow:
- Prepare the data
- Build the graph
- Train
- Predict
public bool Run()
{
    PrepareData();
    BuildGraph();
    using (var sess = tf.Session())
    {
        Train(sess);
        Test(sess);
    }
    TestDataOutput();
    return accuracy_test > 0.98;
}
Loading the dataset
Downloading and unzipping
Dataset: https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip
Download and unzip code (for the wrapped helper methods, see the complete code on GitHub):
string url = "https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip";
Directory.CreateDirectory(Name);
Utility.Web.Download(url, Name, "data_CnnInYourOwnData.zip");
Utility.Compress.UnZip(Name + "\\data_CnnInYourOwnData.zip", Name);
Building the label dictionary
Read the names of the subdirectories under the data directory and use them as the classification dictionary, for the one-hot encoding later.
private void FillDictionaryLabel(string DirPath)
{
    string[] str_dir = Directory.GetDirectories(DirPath, "*", SearchOption.TopDirectoryOnly);
    int str_dir_num = str_dir.Length;
    if (str_dir_num > 0)
    {
        Dict_Label = new Dictionary<Int64, string>();
        for (int i = 0; i < str_dir_num; i++)
        {
            string label = (str_dir[i].Replace(DirPath + "\\", "")).Split('\\').First();
            Dict_Label.Add(i, label);
            print(i.ToString() + " : " + label);
        }
        n_classes = Dict_Label.Count;
    }
}
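With the sample dataset, and assuming Directory.GetDirectories enumerates the class folders alphabetically, this prints:

0 : X
1 : Y
2 : Z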
Reading and shuffling the file lists
Read the train, validation and test file lists from their folders and shuffle them into random order.
Reading the directories
ArrayFileName_Train = Directory.GetFiles(Name + "\\train", "*.*", SearchOption.AllDirectories);
ArrayLabel_Train = GetLabelArray(ArrayFileName_Train);
ArrayFileName_Validation = Directory.GetFiles(Name + "\\validation", "*.*", SearchOption.AllDirectories);
ArrayLabel_Validation = GetLabelArray(ArrayFileName_Validation);
ArrayFileName_Test = Directory.GetFiles(Name + "\\test", "*.*", SearchOption.AllDirectories);
ArrayLabel_Test = GetLabelArray(ArrayFileName_Test);
Deriving the labels
private Int64[] GetLabelArray(string[] FilesArray)
{
    Int64[] ArrayLabel = new Int64[FilesArray.Length];
    for (int i = 0; i < ArrayLabel.Length; i++)
    {
        string[] labels = FilesArray[i].Split('\\');
        string label = labels[labels.Length - 2];
        ArrayLabel[i] = Dict_Label.Single(k => k.Value == label).Key;
    }
    return ArrayLabel;
}
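A worked example with a hypothetical file name:

// "data_CnnInYourOwnData\\train\\Y\\0001.jpg".Split('\\')
//   -> second-to-last segment: "Y"
//   -> Dict_Label lookup     : key 1, the label for this image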
Random shuffling
public (string[], Int64[]) ShuffleArray(int count, string[] images, Int64[] labels)
{
    ArrayList mylist = new ArrayList();
    string[] new_images = new string[count];
    Int64[] new_labels = new Int64[count];
    Random r = new Random();
    for (int i = 0; i < count; i++)
    {
        mylist.Add(i);
    }
    for (int i = 0; i < count; i++)
    {
        int rand = r.Next(mylist.Count);
        new_images[i] = images[(int)(mylist[rand])];
        new_labels[i] = labels[(int)(mylist[rand])];
        mylist.RemoveAt(rand);
    }
    print("shuffle array list: " + count.ToString());
    return (new_images, new_labels);
}
Preloading part of the dataset
The validation and test sets and their labels are small enough to preload into NDArrays in one pass.
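The one-hot encoding below leans on a NumPy idiom that NumSharp reproduces: indexing an identity matrix with a label array selects one row per label. A minimal illustration (values made up for the example):

var labels = new NDArray(new Int64[] { 0, 2, 1 });
var oneHot = np.eye(3)[labels];
// oneHot:
// [[1, 0, 0],   label 0
//  [0, 0, 1],   label 2
//  [0, 1, 0]]   label 1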
private void LoadImagesToNDArray()
{
    // Load labels (one-hot via identity-matrix indexing)
    y_valid = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Validation)];
    y_test = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Test)];
    print("Load Labels To NDArray : OK!");
    // Load images
    x_valid = np.zeros(ArrayFileName_Validation.Length, img_h, img_w, n_channels);
    x_test = np.zeros(ArrayFileName_Test.Length, img_h, img_w, n_channels);
    LoadImage(ArrayFileName_Validation, x_valid, "validation");
    LoadImage(ArrayFileName_Test, x_test, "test");
    print("Load Images To NDArray : OK!");
}

private void LoadImage(string[] a, NDArray b, string c)
{
    for (int i = 0; i < a.Length; i++)
    {
        b[i] = ReadTensorFromImageFile(a[i]);
        Console.Write(".");
    }
    Console.WriteLine();
    Console.WriteLine("Load Images To NDArray: " + c);
}

private NDArray ReadTensorFromImageFile(string file_name)
{
    using (var graph = tf.Graph().as_default())
    {
        var file_reader = tf.read_file(file_name, "file_reader");
        var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: n_channels, name: "DecodeJpeg");
        var cast = tf.cast(decodeJpeg, tf.float32);
        var dims_expander = tf.expand_dims(cast, 0);
        var resize = tf.constant(new int[] { img_h, img_w });
        var bilinear = tf.image.resize_bilinear(dims_expander, resize);
        var sub = tf.subtract(bilinear, new float[] { img_mean });
        var normalized = tf.divide(sub, new float[] { img_std });
        using (var sess = tf.Session(graph))
        {
            return sess.run(normalized);
        }
    }
}
Building the graph
Build the static CNN graph; the learning rate is decayed once every n epochs.
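As a worked example of that schedule (all four values below are illustrative assumptions; the article never prints learning_rate_base, learning_rate_decay, learning_rate_step or learning_rate_min):

// learning_rate_base = 0.001, learning_rate_decay = 0.7,
// learning_rate_step = 5,     learning_rate_min   = 1e-5
// epochs  0-4  : 0.001
// epochs  5-9  : 0.0007
// epochs 10-14 : 0.00049
// ...the rate is multiplied by 0.7 every 5 epochs and clamped at 1e-5.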
#region BuildGraph
public Graph BuildGraph()
{
    var graph = new Graph().as_default();
    tf_with(tf.name_scope("Input"), delegate
    {
        x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
        y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
    });
    var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
    var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
    var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
    var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
    var layer_flat = flatten_layer(pool2);
    var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true);
    var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
    // Some important parameters are saved into the graph, which makes them easy to recover when the model is loaded later
    var img_h_t = tf.constant(img_h, name: "img_h");
    var img_w_t = tf.constant(img_w, name: "img_w");
    var img_mean_t = tf.constant(img_mean, name: "img_mean");
    var img_std_t = tf.constant(img_std, name: "img_std");
    var channels_t = tf.constant(n_channels, name: "img_channels");
    // Variables for the learning-rate decay
    gloabl_steps = tf.Variable(0, trainable: false);
    learning_rate = tf.Variable(learning_rate_base);
    // Subgraph that decodes and normalizes training images
    tf_with(tf.variable_scope("LoadImage"), delegate
    {
        decodeJpeg = tf.placeholder(tf.@byte, name: "DecodeJpeg");
        var cast = tf.cast(decodeJpeg, tf.float32);
        var dims_expander = tf.expand_dims(cast, 0);
        var resize = tf.constant(new int[] { img_h, img_w });
        var bilinear = tf.image.resize_bilinear(dims_expander, resize);
        var sub = tf.subtract(bilinear, new float[] { img_mean });
        normalized = tf.divide(sub, new float[] { img_std }, name: "normalized");
    });
    tf_with(tf.variable_scope("Train"), delegate
    {
        tf_with(tf.variable_scope("Loss"), delegate
        {
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
        });
        tf_with(tf.variable_scope("Optimizer"), delegate
        {
            optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss, global_step: gloabl_steps);
        });
        tf_with(tf.variable_scope("Accuracy"), delegate
        {
            var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
        });
        tf_with(tf.variable_scope("Prediction"), delegate
        {
            cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
            prob = tf.nn.softmax(output_logits, axis: 1, name: "prob");
        });
    });
    return graph;
}

/// <summary>
/// Create a 2D convolution layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="filter_size">size of each filter</param>
/// <param name="num_filters">number of filters (or output feature maps)</param>
/// <param name="stride">filter stride</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
{
    return tf_with(tf.variable_scope(name), delegate
    {
        var num_in_channel = x.shape[x.NDims - 1];
        var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
        var W = weight_variable("W", shape);
        // tf.summary.histogram("weight", W);
        var b = bias_variable("b", new[] { num_filters });
        // tf.summary.histogram("bias", b);
        var layer = tf.nn.conv2d(x, W,
            strides: new[] { 1, stride, stride, 1 },
            padding: "SAME");
        layer += b;
        return tf.nn.relu(layer);
    });
}

/// <summary>
/// Create a max pooling layer
/// </summary>
/// <param name="x">input to max-pooling layer</param>
/// <param name="ksize">size of the max-pooling filter</param>
/// <param name="stride">stride of the max-pooling filter</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor max_pool(Tensor x, int ksize, int stride, string name)
{
    return tf.nn.max_pool(x,
        ksize: new[] { 1, ksize, ksize, 1 },
        strides: new[] { 1, stride, stride, 1 },
        padding: "SAME",
        name: name);
}

/// <summary>
/// Flattens the output of the convolutional layer to be fed into a fully-connected layer
/// </summary>
/// <param name="layer">input array</param>
/// <returns>flattened array</returns>
private Tensor flatten_layer(Tensor layer)
{
    return tf_with(tf.variable_scope("Flatten_layer"), delegate
    {
        var layer_shape = layer.TensorShape;
        var num_features = layer_shape[new Slice(1, 4)].size;
        var layer_flat = tf.reshape(layer, new[] { -1, num_features });
        return layer_flat;
    });
}

/// <summary>
/// Create a weight variable with appropriate initialization
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <returns></returns>
private RefVariable weight_variable(string name, int[] shape)
{
    var initer = tf.truncated_normal_initializer(stddev: 0.01f);
    return tf.get_variable(name,
        dtype: tf.float32,
        shape: shape,
        initializer: initer);
}

/// <summary>
/// Create a bias variable with appropriate initialization
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <returns></returns>
private RefVariable bias_variable(string name, int[] shape)
{
    var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
    return tf.get_variable(name,
        dtype: tf.float32,
        initializer: initial);
}

/// <summary>
/// Create a fully-connected layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="num_units">number of hidden units in the fully-connected layer</param>
/// <param name="name">layer name</param>
/// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param>
/// <returns>The output array</returns>
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
{
    return tf_with(tf.variable_scope(name), delegate
    {
        var in_dim = x.shape[1];
        var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
        var b = bias_variable("b_" + name, new[] { num_units });
        var layer = tf.matmul(x, W) + b;
        if (use_relu)
            layer = tf.nn.relu(layer);
        return layer;
    });
}
#endregion
Training and saving the model
Batch reading uses SharpCV's cv2.imread, which loads a local image file directly into an NDArray, so CV and NumSharp data interoperate with no conversion step.
.NET's thread-safe BlockingCollection<T> is used in place of TensorFlow's native FIFOQueue queue runner; see the minimal sketch after this list.
Samples must be read from disk into memory before they can be trained on, so the session runs multiple threads with a bounded queue between them: a worker thread does the file IO and enqueues batches while the main thread dequeues them and trains. Reading and training therefore overlap, which shortens total training time.
The model can be saved on every epoch, or only whenever the best validation accuracy so far improves.
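Distilled from the training loop below, the producer-consumer pattern amounts to this minimal, self-contained sketch (ints stand in for image batches; it uses the System.Collections.Concurrent and System.Threading.Tasks namespaces already imported above):

var queue = new BlockingCollection<int>(boundedCapacity: 10); // bounded: Add blocks when the queue is full
Task.Run(() =>
{
    for (int i = 0; i < 100; i++)
        queue.Add(i);        // producer thread: stands in for reading a batch from disk
    queue.CompleteAdding();  // signals the consumer that no more items are coming
});
foreach (var item in queue.GetConsumingEnumerable())
{
    // consumer (main) thread: stands in for sess.run(optimizer, ...) on the batch;
    // blocks until an item arrives and ends once the producer calls CompleteAdding
    Console.WriteLine(item);
}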
#region Train
public void Train(Session sess)
{
    // Number of training iterations in each epoch
    var num_tr_iter = (ArrayLabel_Train.Length) / batch_size;
    var init = tf.global_variables_initializer();
    sess.run(init);
    var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);
    path_model = Name + "\\MODEL";
    Directory.CreateDirectory(path_model);
    float loss_val = 100.0f;
    float accuracy_val = 0f;
    var sw = new Stopwatch();
    sw.Start();
    foreach (var epoch in range(epochs))
    {
        print($"Training epoch: {epoch + 1}");
        // Randomly shuffle the training data at the beginning of each epoch
        (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train);
        y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)];
        // Decay the learning rate on schedule
        if (learning_rate_step != 0)
        {
            if ((epoch != 0) && (epoch % learning_rate_step == 0))
            {
                learning_rate_base = learning_rate_base * learning_rate_decay;
                if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; }
                sess.run(tf.assign(learning_rate, learning_rate_base));
            }
        }
        // Load local images asynchronously through a bounded queue to improve training efficiency
        BlockingCollection<(NDArray c_x, NDArray c_y, int iter)> BlockC = new BlockingCollection<(NDArray c_x, NDArray c_y, int iter)>(TrainQueueCapa);
        Task.Run(() =>
        {
            foreach (var iteration in range(num_tr_iter))
            {
                var start = iteration * batch_size;
                var end = (iteration + 1) * batch_size;
                (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end);
                BlockC.Add((x_batch, y_batch, iteration));
            }
            BlockC.CompleteAdding();
        });
        foreach (var item in BlockC.GetConsumingEnumerable())
        {
            sess.run(optimizer, (x, item.c_x), (y, item.c_y));
            if (item.iter % display_freq == 0)
            {
                // Calculate and display the batch loss and accuracy
                var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y));
                loss_val = result[0];
                accuracy_val = result[1];
                print("CNN:" + ($"iter {item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"));
                sw.Restart();
            }
        }
        // Run validation after every epoch
        (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
        print("CNN:" + "---------------------------------------------------------");
        print("CNN:" + $"global steps: {sess.run(gloabl_steps)}, learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
        print("CNN:" + "---------------------------------------------------------");
        if (SaverBest)
        {
            if (accuracy_val > max_accuracy)
            {
                max_accuracy = accuracy_val;
                saver.save(sess, path_model + "\\CNN_Best");
                print("CKPT model saved.");
            }
        }
        else
        {
            saver.save(sess, path_model + string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val));
            print("CKPT model saved.");
        }
    }
    Write_Dictionary(path_model + "\\dic.txt", Dict_Label);
}

private void Write_Dictionary(string path, Dictionary<Int64, string> mydic)
{
    FileStream fs = new FileStream(path, FileMode.Create);
    StreamWriter sw = new StreamWriter(fs);
    foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); }
    sw.Flush();
    sw.Close();
    fs.Close();
    print("Write_Dictionary");
}

private (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
    var perm = np.random.permutation(y.shape[0]);
    np.random.shuffle(perm);
    return (x[perm], y[perm]);
}

private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
    var slice = new Slice(start, end);
    var x_batch = x[slice];
    var y_batch = y[slice];
    return (x_batch, y_batch);
}

private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end)
{
    NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels);
    int n = 0;
    for (int i = start; i < end; i++)
    {
        // Read the image as grayscale straight into an NDArray, then normalize it
        // through the "LoadImage" subgraph built in BuildGraph
        NDArray img4 = cv2.imread(x[i], IMREAD_COLOR.IMREAD_GRAYSCALE);
        x_batch[n] = sess.run(normalized, (decodeJpeg, img4));
        n++;
    }
    var slice = new Slice(start, end);
    var y_batch = y[slice];
    return (x_batch, y_batch);
}
#endregion
Predicting on the test set
The trained model is run on the test set and the accuracy is tallied.
The graph includes a node that extracts the probability of the Top-1 prediction, so the detailed per-image predictions can be dumped at test time, which is handy for debugging and tuning in a real project.
public void Test(Session sess)
{
    (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test));
    print("CNN:" + "---------------------------------------------------------");
    print("CNN:" + $"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
    print("CNN:" + "---------------------------------------------------------");
    (Test_Cls, Test_Data) = sess.run((cls_prediction, prob), (x, x_test));
}

private void TestDataOutput()
{
    for (int i = 0; i < ArrayLabel_Test.Length; i++)
    {
        Int64 real = ArrayLabel_Test[i];
        int predict = (int)(Test_Cls[i]);
        var probability = Test_Data[i, predict];
        string result = (real == predict) ? "OK" : "NG";
        string fileName = ArrayFileName_Test[i];
        string real_str = Dict_Label[real];
        string predict_str = Dict_Label[predict];
        print((i + 1).ToString() + "|" + "result:" + result + "|" + "real_str:" + real_str + "|"
            + "predict_str:" + predict_str + "|" + "probability:" + probability.GetSingle().ToString() + "|"
            + "fileName:" + fileName);
    }
}
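A line of the resulting log looks like this (all values are illustrative, not real output):

1|result:OK|real_str:X|predict_str:X|probability:0.99|fileName:data_CnnInYourOwnData\test\X\0001.jpg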
Summary
This article showed TensorFlow under .NET applied to a real industrial vision-inspection project. Using SciSharp's TensorFlow.NET, we built a simple CNN image-classification model with an input layer, convolution and pooling layers, a flatten layer, a fully connected layer and an output layer, the essential layers of a CNN classifier, and used it to classify real production-line images with good accuracy.
The complete code can be trained directly on your own dataset. It has been tested extensively on the factory floor and runs under GPU or CPU; switching the training environment only requires swapping the tensorflow.dll redistributable.
A trained model can be deployed either as a "CKPT + Meta" checkpoint or frozen into a single "PB" file, and both deployment and inference then run entirely on the .NET platform, slotting straight into existing plant software. This replaces the usual Python approach of standing up a Flask server just for data exchange: the plant PC needs no Python or TensorFlow environment (no pile of packages to install on the factory's existing machines), and everything ships as conventional .NET DLL references.
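As a rough sketch of the frozen-PB route under TensorFlow.NET (the file name and node names below are my assumptions, derived from the scopes defined in BuildGraph; verify them against your exported graph):

var graph = new Graph();
graph.Import("frozen_model.pb");  // hypothetical file name
using (var sess = tf.Session(graph))
{
    var input_op = graph.OperationByName("Input/X");                       // assumed node name
    var output_op = graph.OperationByName("Train/Prediction/predictions"); // assumed node name
    // img_nd: a preprocessed 1 × 64 × 64 × 1 NDArray, e.g. produced by the LoadImage subgraph
    var results = sess.run(output_op.outputs[0], new FeedItem(input_op.outputs[0], img_nd));
}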
That wraps up this article on training your own dataset with TensorFlow.NET in C#. For more on C# and TensorFlow.NET, search 腳本之家's earlier articles, and we hope you will keep supporting 腳本之家!