Python TensorFlow 2.6獲取MNIST數(shù)據(jù)的示例代碼
1 Python TensorFlow 2.6 獲取 MNIST 數(shù)據(jù)
1.1 獲取 MNIST 數(shù)據(jù)
獲取 MNIST 數(shù)據(jù)
import numpy as np import tensorflow as tf from tensorflow.keras import datasets print(tf.__version__) (train_data, train_label), (test_data, test_label) = datasets.mnist.load_data() np.savez('D:\\OneDrive\\桌面\\mnist.npz', train_data = train_data, train_label = train_label, test_data = test_data, test_label = test_label)
C:\ProgramData\Anaconda3\envs\tensorflow\python.exe E:/SourceCode/PyCharm/Test/study/exam.py 2.6.0 Process finished with exit code 0
1.2 檢查 MNIST 數(shù)據(jù)
import matplotlib.pyplot as plt import numpy as np data = np.load('D:\\OneDrive\\桌面\\mnist.npz') print(data.files) image = data['train_data'][0:100] label = data['train_label'].reshape(-1, ) print(label) plt.figure(figsize = (10, 10)) for i in range(100): print('%f, %f' % (i, label[i])) plt.subplot(10, 10, i + 1) plt.imshow(image[i]) plt.show()
2 Python 將npz數(shù)據(jù)保存為txt
import numpy as np # 加載mnist數(shù)據(jù) data = np.load('D:\\學(xué)習(xí)\\mnist.npz') # 獲取 訓(xùn)練數(shù)據(jù) train_image = data['x_test'] train_label = data['y_test'] train_image = train_image.reshape(train_image.shape[0], -1) train_image = train_image.astype(np.int32) train_label = train_label.astype(np.int32) train_label = train_label.reshape(-1, 1) index = 0 file = open('D:\\OneDrive\\桌面\\predict.txt', 'w+') for arr in train_image: file.write('{0}->{1}\n'.format(train_label[index][0], ','.join(str(i) for i in arr))) index = index + 1 file.close()
3 Java 獲取數(shù)據(jù)并使用SVM訓(xùn)練
package com.xu.opencv; import java.io.BufferedReader; import java.io.FileReader; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import org.opencv.core.Core; import org.opencv.core.CvType; import org.opencv.core.Mat; import org.opencv.core.TermCriteria; import org.opencv.ml.Ml; import org.opencv.ml.SVM; /** * @author Administrator */ public class Train { static { System.loadLibrary(Core.NATIVE_LIBRARY_NAME); } public static void main(String[] args) throws Exception { predict(); } public static void predict() throws Exception { SVM svm = SVM.load("D:\\OneDrive\\桌面\\ai.xml"); BufferedReader reader = new BufferedReader(new FileReader("D:\\OneDrive\\桌面\\predict.txt")); Mat train = new Mat(6, 28 * 28, CvType.CV_32FC1); Mat label = new Mat(1, 6, CvType.CV_32SC1); Map<String, Mat> map = new HashMap<>(2); int index = 0; String line = null; while ((line = reader.readLine()) != null) { int[] data = Arrays.asList(line.split("->")[1].split(",")).stream().mapToInt(Integer::parseInt).toArray(); for (int i = 0; i < 28 * 28; i++) { train.put(index, i, data[i]); } label.put(index, 0, Integer.parseInt(line.split("->")[0])); index++; if (index >= 6) { break; } } Mat response = new Mat(); svm.predict(train, response); for (int i = 0; i < response.height(); i++) { System.out.println(response.get(i, 0)[0]); } } public static void train() throws Exception { SVM svm = SVM.create(); svm.setC(1); svm.setP(0); svm.setNu(0); svm.setCoef0(0); svm.setGamma(1); svm.setDegree(0); svm.setType(SVM.C_SVC); svm.setKernel(SVM.LINEAR); svm.setTermCriteria(new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 1000, 0)); Map<String, Mat> map = read("D:\\OneDrive\\桌面\\data.txt"); svm.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label")); svm.save("D:\\OneDrive\\桌面\\ai.xml"); } public static Map<String, Mat> read(String path) throws Exception { BufferedReader reader = new BufferedReader(new FileReader(path)); String line = null; Mat train = new Mat(60000, 28 * 28, CvType.CV_32FC1); Mat label = new Mat(1, 60000, CvType.CV_32SC1); Map<String, Mat> map = new HashMap<>(2); int index = 0; while ((line = reader.readLine()) != null) { int[] data = Arrays.asList(line.split("->")[1].split(",")).stream().mapToInt(Integer::parseInt).toArray(); for (int i = 0; i < 28 * 28; i++) { train.put(index, i, data[i]); } label.put(index, 0, Integer.parseInt(line.split("->")[0])); index++; } map.put("train", train); map.put("label", label); reader.close(); return map; } }
4 Python 測試SVM準(zhǔn)確度
9.8% 求幫助
import cv2 as cv import numpy as np # 加載預(yù)測數(shù)據(jù) data = np.load('D:\\學(xué)習(xí)\\mnist.npz') print(data.files) # 預(yù)測數(shù)據(jù) 處理 test_image = data['x_test'] test_label = data['y_test'] test_image = test_image.reshape(test_image.shape[0], -1) test_image = test_image.astype(np.float32) test_label = test_label.astype(np.float32) test_label = test_label.reshape(-1, 1) svm = cv.ml.SVM_load('D:\\OneDrive\\桌面\\ai.xml') predict = svm.predict(test_image) predict = predict[1].reshape(-1, 1).astype(np.int32) result = (predict == test_label.astype(np.int32)) print('{0}%'.format(str(result.mean() * 100)))
C:\ProgramData\Anaconda3\envs\opencv\python.exe E:/SourceCode/PyCharm/OpenCV/svm/predict.py ['x_train', 'y_train', 'x_test', 'y_test'] 9.8% Process finished with exit code 0
以上就是Python TensorFlow 2.6獲取MNIST數(shù)據(jù)的示例代碼的詳細(xì)內(nèi)容,更多關(guān)于Python TensorFlow獲取MNIST的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python字典中g(shù)et()函數(shù)的基本用法實例
在字典內(nèi)置的方法中,想說的方法為get,這個方法是通過鍵來獲取相應(yīng)的值,但是如果相應(yīng)的鍵不存在則返回None,這篇文章主要給大家介紹了關(guān)于python字典中g(shù)et()函數(shù)的基本用法,需要的朋友可以參考下2022-03-03python中split(),?os.path.split()和os.path.splitext()的用法
本文主要介紹了python中split(),?os.path.split()和os.path.splitext()的用法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02Python異常處理如何才能寫得優(yōu)雅(retrying模塊)
異常就是程序運(yùn)行時發(fā)生錯誤的信號,下面這篇文章主要給大家介紹了關(guān)于Python異常處理的相關(guān)資料,文中通過實例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-03-03pytorch通過訓(xùn)練結(jié)果的復(fù)現(xiàn)設(shè)置隨機(jī)種子
這篇文章主要介紹了pytorch通過訓(xùn)練結(jié)果的復(fù)現(xiàn)設(shè)置隨機(jī)種子的操作,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2021-06-06python自動化測試selenium核心技術(shù)三種等待方式詳解
這篇文章主要為大家介紹了python自動化測試selenium的核心技術(shù)三種等待方式示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步早日升職加薪2021-11-11