Python機器學習應用之基于線性判別模型的分類篇詳解
一、Introduction
線性判別模型(LDA)在模式識別領域(比如人臉識別等圖形圖像識別領域)中有非常廣泛的應用。LDA是一種監(jiān)督學習的降維技術,也就是說它的數(shù)據(jù)集的每個樣本是有類別輸出的。這點和PCA不同。PCA是不考慮樣本類別輸出的無監(jiān)督降維技術。 LDA的思想可以用一句話概括,就是“投影后類內方差最小,類間方差最大”。我們要將數(shù)據(jù)在低維度上進行投影,投影后希望每一種類別數(shù)據(jù)的投影點盡可能的接近,而不同類別的數(shù)據(jù)的類別中心之間的距離盡可能的大。即:將數(shù)據(jù)投影到維度更低的空間中,使得投影后的點,會形成按類別區(qū)分,一簇一簇的情況,相同類別的點,將會在投影后的空間中更接近方法。
1 LDA的優(yōu)點
- 在降維過程中可以使用類別的先驗知識經驗,而像PCA這樣的無監(jiān)督學習則無法使用類別先驗知識;
- LDA在樣本分類信息依賴均值而不是方差的時候,比PCA之類的算法較優(yōu)
2 LDA的缺點
- LDA不適合對非高斯分布樣本進行降維,PCA也有這個問題
- LDA降維最多降到類別數(shù) k-1 的維數(shù),如果我們降維的維度大于 k-1,則不能使用 LDA。當然目前有一些LDA的進化版算法可以繞過這個問題
- LDA在樣本分類信息依賴方差而不是均值的時候,降維效果不好
- LDA可能過度擬合數(shù)據(jù)
3 LDA在模式識別領域與自然語言處理領域的區(qū)別
在自然語言處理領域,LDA是隱含狄利克雷分布,它是一種處理文檔的主題模型。本文討論的是線性判別分析 LDA除了可以用于降維以外,還可以用于分類。一個常見的LDA分類基本思想是假設各個類別的樣本數(shù)據(jù)符合高斯分布,這樣利用LDA進行投影后,可以利用極大似然估計計算各個類別投影數(shù)據(jù)的均值和方差,進而得到該類別高斯分布的概率密度函數(shù)。當一個新的樣本到來后,我們可以將它投影,然后將投影后的樣本特征分別帶入各個類別的高斯分布概率密度函數(shù),計算它屬于這個類別的概率,最大的概率對應的類別即為預測類別
二、Demo
#%%導入基本庫 # 基礎數(shù)組運算庫導入 import numpy as np # 畫圖庫導入 import matplotlib.pyplot as plt # 導入三維顯示工具 from mpl_toolkits.mplot3d import Axes3D # 導入LDA模型 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # 導入demo數(shù)據(jù)制作方法 from sklearn.datasets import make_classification #%%模型訓練 # 制作四個類別的數(shù)據(jù),每個類別100個樣本 X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0, n_classes=4, n_informative=2, n_clusters_per_class=1, class_sep=3, random_state=10) # 將四個類別的數(shù)據(jù)進行三維顯示 fig = plt.figure() ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20) ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=y) plt.show()
#%%建立 LDA 模型 lda = LinearDiscriminantAnalysis() # 進行模型訓練 lda.fit(X, y) #%%查看lda的參數(shù) print(lda.get_params())
#%%數(shù)據(jù)可視化 #模型預測 X_new = lda.transform(X) # 可視化預測數(shù)據(jù) plt.scatter(X_new[:, 0], X_new[:, 1], marker='o', c=y) plt.show()
#%%使用新的數(shù)據(jù)進行測試 a = np.array([[-1, 0.1, 0.1]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a)) a = np.array([[-12, -100, -91]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a)) a = np.array([[-12, -0.1, -0.1]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a)) a = np.array([[0.1, 90.1, 9.1]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a))
三、基于LDA 手寫數(shù)字的分類
#%%導入庫函數(shù) # 導入手寫數(shù)據(jù)集 MNIST from sklearn.datasets import load_digits # 導入訓練集分割方法 from sklearn.model_selection import train_test_split # 導入LDA模型 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # 導入預測指標計算函數(shù)和混淆矩陣計算函數(shù) from sklearn.metrics import classification_report, confusion_matrix # 導入繪圖包 import seaborn as sns import matplotlib import matplotlib.pyplot as plt #%% 導入MNIST數(shù)據(jù)集 mnist = load_digits() # 查看數(shù)據(jù)集信息 print('The Mnist dataeset:\n',mnist) # 分割數(shù)據(jù)為訓練集和測試集 x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=2)
#%%## 輸出示例圖像 images = range(0,9) plt.figure(dpi=100) for i in images: plt.subplot(330 + 1 + i) plt.imshow(x[i].reshape(8, 8), cmap = matplotlib.cm.binary,interpolation="nearest") # show the plot plt.show()
#%%利用LDA對手寫數(shù)字進行訓練與預測 m_lda = LinearDiscriminantAnalysis()# 建立 LDA 模型 # 進行模型訓練 m_lda.fit(x, y) # 進行模型預測 x_new = m_lda.transform(x) # 可視化預測數(shù)據(jù) plt.scatter(x_new[:, 0], x_new[:, 1], marker='o', c=y) plt.title('MNIST with LDA Model') plt.show()
#%% 進行測試集數(shù)據(jù)的類別預測 y_test_pred = m_lda.predict(test_x) print("測試集的真實標簽:\n", test_y) print("測試集的預測標簽:\n", y_test_pred) #%% 進行預測結果指標統(tǒng)計 統(tǒng)計每一類別的預測準確率、召回率、F1分數(shù) print(classification_report(test_y, y_test_pred)) # 計算混淆矩陣 C2 = confusion_matrix(test_y, y_test_pred) # 打混淆矩陣 print(C2) # 將混淆矩陣以熱力圖的防線顯示 sns.set() f, ax = plt.subplots() # 畫熱力圖 sns.heatmap(C2, cmap="YlGnBu_r", annot=True, ax=ax) # 標題 ax.set_title('confusion matrix') # x軸為預測類別 ax.set_xlabel('predict') # y軸實際類別 ax.set_ylabel('true') plt.show()
四、小結
LDA適用于線性可分數(shù)據(jù),在非線性數(shù)據(jù)上要謹慎使用。 886~~~
到此這篇關于Python機器學習應用之基于線性判別模型的分類篇詳解的文章就介紹到這了,更多相關Python 線性判別模型的分類內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Django+Nginx+uWSGI 定時任務的實現(xiàn)方法
本文主要介紹了Django+Nginx+uWSGI 定時任務的實現(xiàn)方法,文中通過示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下2022-01-01PyQt5 QSerialPort子線程操作的實現(xiàn)
這篇文章主要介紹了PyQt5 QSerialPort子線程操作的實現(xiàn),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2018-04-04selenium+python實現(xiàn)基本自動化測試的示例代碼
這篇文章主要介紹了selenium+python實現(xiàn)基本自動化測試的示例代碼,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2021-01-01pyinstaller打包可執(zhí)行文件出現(xiàn)KeyError的問題
這篇文章主要介紹了pyinstaller打包可執(zhí)行文件出現(xiàn)KeyError的問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-11-11?分享一個Python?遇到數(shù)據(jù)庫超好用的模塊
這篇文章主要介紹了?分享一個Python?遇到數(shù)據(jù)庫超好用的模塊,SQLALchemy這個模塊,該模塊是Python當中最有名的ORM框架,該框架是建立在數(shù)據(jù)庫API之上,使用關系對象映射進行數(shù)據(jù)庫的操作,,需要的朋友可以參考下2022-04-04