Python通過樸素貝葉斯和LSTM分別實(shí)現(xiàn)新聞文本分類
一、項(xiàng)目背景
本項(xiàng)目來源于天池?賽,利?機(jī)器學(xué)習(xí)和深度學(xué)習(xí)等知識(shí),對(duì)新聞?本進(jìn)?分類。?共有14個(gè)分類類別:財(cái)經(jīng)、彩票、房產(chǎn)、股票、家居、教育、科技、社會(huì)、時(shí)尚、時(shí)政、體育、星座、游戲、娛樂。
最終將測(cè)試集的預(yù)測(cè)結(jié)果上傳??賽官?,可查看排名。評(píng)價(jià)標(biāo)準(zhǔn)為類別f1_score的均值,提交結(jié)果與實(shí)際測(cè)試集的類別進(jìn)?對(duì)?。(不要求結(jié)果領(lǐng)先,但求真才實(shí)學(xué))
二、數(shù)據(jù)處理與分析
本次大賽提供的材料是由csv格式編寫,只需調(diào)用python中的pandas庫讀取即可。為了更直觀的觀察數(shù)據(jù),我計(jì)算了文檔的平均長度,以及每個(gè)標(biāo)簽分別對(duì)應(yīng)的文檔。(sen字典與tag字典的獲取方法會(huì)在后文中展示,此步只用來呈現(xiàn)數(shù)據(jù)分布,運(yùn)行時(shí)可先跳過)
import matplotlib.pyplot as plt from tqdm import tqdm import time from numpy import * import pandas as pd print('count: 200000') #詞典sen中,每個(gè)標(biāo)簽對(duì)應(yīng)其所有句子的二維列表 print('average: '+str(sum([[sum(sen[i][j]) for j in range(len(sen[i]))] for i in sen])/200000)) x = [] y = [] for key,value in tag.items(): #詞典tag中,每個(gè)標(biāo)簽對(duì)應(yīng)該標(biāo)簽下的句子數(shù)目 x.append(key) y.append(value) plt.bar(x,y) plt.show()
最終我們得到了以下結(jié)果:
平均文檔長約907詞,每個(gè)標(biāo)簽對(duì)應(yīng)的文檔數(shù)從標(biāo)簽0至13逐個(gè)減少。
三、基于機(jī)器學(xué)習(xí)的文本分類–樸素貝葉斯
1. 模型介紹
樸素貝葉斯分類器的基本思想是利用特征項(xiàng)和類別的聯(lián)合概率來估計(jì)給定文檔的類別概率。假設(shè)文本是基于詞的一元模型,即文本中當(dāng)前詞的出現(xiàn)依賴于文本類別,但不依賴于其他詞及文本的長度,也就是說,詞與詞之間是獨(dú)立的。根據(jù)貝葉斯公式,文檔Doc屬于Ci類的概率為
文檔Doc采用TF向量表示法,即文檔向量V的分量為相應(yīng)特征在該文檔中出現(xiàn)的頻度,文檔Doc屬于Ci類文檔的概率為
其中,TF(ti,Doc)是文檔Doc中特征ti出現(xiàn)的頻度,為了防止出現(xiàn)不在詞典中的詞導(dǎo)致概率為0的情況,我們?nèi)(ti|Ci)是對(duì)Ci類文檔中特征ti出現(xiàn)的條件概率的拉普拉斯概率估計(jì):
這里,TF(ti,Ci)是Ci類文檔中特征ti出現(xiàn)的頻度,|V|為特征集的大小,即文檔表示中所包含的不同特征的總數(shù)目。
2. 代碼結(jié)構(gòu)
我直接通過python自帶的open()函數(shù)讀取文件,并建立對(duì)應(yīng)詞典,設(shè)定停用詞,這里的停用詞選擇了words字典中出現(xiàn)在100000個(gè)文檔以上的所有詞。訓(xùn)練集取前19萬個(gè)文檔,測(cè)試集取最后一萬個(gè)文檔。
train_df = open('./data/train_set.csv').readlines()[1:] train = train_df[0:190000] test = train_df[190000:200000] true_test = open('./data/test_a.csv').readlines()[1:] tag = {str(i):0 for i in range(0,14)} sen = {str(i):{} for i in range(0,14)} words={} stop_words = {'4149': 1, '1519': 1, '2465': 1, '7539': 1, ...... }
接著,我們需要建立標(biāo)簽詞典和句子詞典,用tqdm函數(shù)來顯示進(jìn)度。
for line in tqdm(train_df): cur_line = line.split('\t') cur_tag = cur_line[0] tag[cur_tag] += 1 cur_line = cur_line[1][:-1].split(' ') for i in cur_line: if i not in words: words[i] = 1 else: words[i] += 1 if i not in sen[cur_tag]: sen[cur_tag][i] = 1 else: sen[cur_tag][i] += 1
為了便于計(jì)算,我定義了如下函數(shù),其中mul()用來計(jì)算列表中所有數(shù)的乘積,prob_clas() 用來計(jì)算P(Ci|Doc),用probability()來計(jì)算P(ti|Ci),在probability() 函數(shù)中,我將輸出結(jié)果中分子+1,分母加上字典長度,實(shí)現(xiàn)拉普拉斯平滑處理。
def mul(l): res = 1 for i in l: res *= i return res def prob_clas(clas): return tag[clas]/(sum([tag[i] for i in tag])) def probability(char,clas): #P(特征|類別) if char not in sen[clas]: num_char = 0 else: num_char = sen[clas][char] return (1+num_char)/(len(sen[clas])+len(words))
在做好所有準(zhǔn)備工作,定義好函數(shù)后,分別對(duì)測(cè)試集中的每一句話計(jì)算十四個(gè)標(biāo)簽對(duì)應(yīng)概率,并將概率最大的標(biāo)簽儲(chǔ)存在預(yù)測(cè)列表中,用tqdm函數(shù)來顯示進(jìn)度。
PRED = [] for line in tqdm(true_test): result = {str(i):0 for i in range(0,14)} cur_line = line[:-1].split(' ') clas = cur_tag for i in result: prob = [] for j in cur_line: if j in stop_words: continue prob.append(log(probability(j,i))) result[i] = log(prob_clas(i))+sum(prob) for key,value in result.items(): if(value == max(result.values())): pred = int(key) PRED.append(pred)
最后把結(jié)果儲(chǔ)存在csv文件中上傳網(wǎng)站,提交后查看成績(jī)。(用此方法編寫的csv文件需要打開后刪去第一列再上傳)
res=pd.DataFrame() res['label']=PRED res.to_csv('test_TL.csv')
3. 結(jié)果分析
在訓(xùn)練前19萬個(gè)文檔,測(cè)試后一萬個(gè)文檔的過程中,我不斷調(diào)整停用詞取用列表,分別用TF和TF-IDF向量表示法進(jìn)行了測(cè)試,結(jié)果發(fā)現(xiàn)使用TF表示法準(zhǔn)確性較高,最后取用停用詞為出現(xiàn)在十萬個(gè)文檔以上的詞。最終得出最高效率為0.622。
在提交至網(wǎng)站后,對(duì)五萬個(gè)文檔進(jìn)行測(cè)試的F1值僅有 0.29左右,效果較差。
四、基于深度學(xué)習(xí)的文本分類–LSTM
1. 模型介紹
除了傳統(tǒng)的機(jī)器學(xué)習(xí)方法,我使用了深度學(xué)習(xí)中的LSTM(Long Short-Term Memory)長短期記憶網(wǎng)絡(luò),來嘗試處理新聞文本分類,希望能有更高的準(zhǔn)確率。LSTM它是一種時(shí)間循環(huán)神經(jīng)網(wǎng)絡(luò),適合于處理和預(yù)測(cè)時(shí)間序列中間隔和延遲相對(duì)較長的重要事件。LSTM 已經(jīng)在科技領(lǐng)域有了多種應(yīng)用?;?LSTM 的系統(tǒng)可以學(xué)習(xí)翻譯語言、控制機(jī)器人、圖像分析、文檔摘要、語音識(shí)別圖像識(shí)別、手寫識(shí)別、控制聊天機(jī)器人、預(yù)測(cè)疾病、點(diǎn)擊率和股票、合成音樂等等任務(wù)。我采用深度學(xué)習(xí)庫Keras來建立LSTM模型,進(jìn)行文本分類。
對(duì)于卷積神經(jīng)網(wǎng)絡(luò)CNN和循環(huán)網(wǎng)絡(luò)RNN而言,隨著時(shí)間的不斷增加,隱藏層一次又一次地乘以權(quán)重W。假如某個(gè)權(quán)重w是一個(gè)接近于0或者大于1的數(shù),隨著乘法次數(shù)的增加,這個(gè)權(quán)重值會(huì)變得很小或者很大,造成反向傳播時(shí)梯度計(jì)算變得很困難,造成梯度爆炸或者梯度消失的情況,模型難以訓(xùn)練。也就是說一般的RNN模型對(duì)于長時(shí)間距離的信息記憶很差,因此LSTM應(yīng)運(yùn)而生。
LSTM長短期記憶網(wǎng)絡(luò)可以更好地解決這個(gè)問題。在LSTM的一個(gè)單元中,有四個(gè)顯示為黃色框的網(wǎng)絡(luò)層,每個(gè)層都有自己的權(quán)重,如以 σ 標(biāo)記的層是 sigmoid 層,tanh是一個(gè)激發(fā)函數(shù)。這些紅圈表示逐點(diǎn)或逐元素操作。單元狀態(tài)在通過 LSTM 單元時(shí)幾乎沒有交互,使得大部分信息得以保留,單元狀態(tài)僅通過這些控制門(gate)進(jìn)行修改。第一個(gè)控制門是遺忘門,用來決定我們會(huì)從單元狀態(tài)中丟棄什么信息。第二個(gè)門是更新門,用以確定什么樣的新信息被存放到單元狀態(tài)中。最后一個(gè)門是輸出門,我們需要確定輸出什么樣的值。總結(jié)來說 LSTM 單元由單元狀態(tài)和一堆用于更新信息的控制門組成,讓信息部分傳遞到隱藏層狀態(tài)。
2. 代碼結(jié)構(gòu)
首先是初始數(shù)據(jù)的設(shè)定和包的調(diào)用。考慮到平均句長約900,這里取最大讀取長度為平均長度的2/3,即max_len為600,之后可通過調(diào)整該參數(shù)來調(diào)整學(xué)習(xí)效率。
from tqdm import tqdm import pandas as pd import time import matplotlib.pyplot as plt import seaborn as sns from numpy import * from sklearn import metrics from sklearn.preprocessing import LabelEncoder,OneHotEncoder from keras.models import Model from keras.layers import LSTM, Activation, Dense, Dropout, Input, Embedding from keras.optimizers import rmsprop_v2 from keras.preprocessing import sequence from keras.callbacks import EarlyStopping from keras.models import load_model import os.path max_words = 7549 #字典最大編號(hào) # 可通過調(diào)節(jié)max_len調(diào)整模型效果和學(xué)習(xí)速度 max_len = 600 #句子的最大長度 stop_words = {}
接下來,我們定義一個(gè)將DataFrame的格式轉(zhuǎn)化為矩陣的函數(shù)。該函數(shù)輸出一個(gè)長度為600的二維文檔列表和其對(duì)應(yīng)的標(biāo)簽值。
def to_seq(dataframe): x = [] y = array([[0]*int(i)+[1]+[0]*(13-int(i)) for i in dataframe['label']]) for i in tqdm(dataframe['text']): cur_sentense = [] for word in i.split(' '): if word not in stop_words: #最終并未采用停用詞列表 cur_sentense.append(word) x.append(cur_sentense) return sequence.pad_sequences(x,maxlen=max_len),y
接下來是模型的主體函數(shù)。該函數(shù)輸入測(cè)試的文檔,測(cè)試集的真值,訓(xùn)練集和檢驗(yàn)集,輸出預(yù)測(cè)得到的混淆矩陣。具體代碼介紹,見下列代碼中的注釋。
def test_file(text,value,train,val): ## 定義LSTM模型 inputs = Input(name='inputs',shape=[max_len]) ## Embedding(詞匯表大小,batch大小,每個(gè)新聞的詞長) layer = Embedding(max_words+1,128,input_length=max_len)(inputs) layer = LSTM(128)(layer) layer = Dense(128,activation="relu",name="FC1")(layer) layer = Dropout(0.5)(layer) layer = Dense(14,activation="softmax",name="FC2")(layer) model = Model(inputs=inputs,outputs=layer) model.summary() model.compile(loss="categorical_crossentropy",optimizer=rmsprop_v2.RMSprop(),metrics=["accuracy"]) ## 模型建立好之后開始訓(xùn)練,如果已經(jīng)保存訓(xùn)練文件(.h5格式),則直接調(diào)取即可 if os.path.exists('my_model.h5') == True: model = load_model('my_model.h5') else: train_seq_mat,train_y = to_seq(train) val_seq_mat,val_y = to_seq(val) model.fit(train_seq_mat,train_y,batch_size=128,epochs=10, #可通過epochs數(shù)來調(diào)整準(zhǔn)確率和運(yùn)算速度 validation_data=(val_seq_mat,val_y)) model.save('my_model.h5') ## 開始預(yù)測(cè) test_pre = model.predict(text) ##計(jì)算混淆函數(shù) confm = metrics.confusion_matrix(argmax(test_pre,axis=1),argmax(value,axis=1)) print(metrics.classification_report(argmax(test_pre,axis=1),argmax(value,axis=1))) return confm
訓(xùn)練過程如下圖所示。
為了更直觀的表現(xiàn)結(jié)果,定義如下函數(shù)繪制圖像。
def plot_fig(matrix): Labname = [str(i) for i in range(14)] plt.figure(figsize=(8,8)) sns.heatmap(matrix.T, square=True, annot=True, fmt='d', cbar=False,linewidths=.8, cmap="YlGnBu") plt.xlabel('True label',size = 14) plt.ylabel('Predicted label',size = 14) plt.xticks(arange(14)+0.5,Labname,size = 12) plt.yticks(arange(14)+0.3,Labname,size = 12) plt.show() return
最后,只需要通過pandas讀取csv文件,按照比例分為訓(xùn)練集、檢驗(yàn)集和測(cè)試集(這里選用比例為15:2:3),即可完成全部的預(yù)測(cè)過程。
def test_main(): train_df = pd.read_csv("./data/train_set.csv",sep='\t',nrows=200000) train = train_df.iloc[0:150000,:] test = train_df.iloc[150000:180000,:] val = train_df.iloc[180000:,:] test_seq_mat,test_y = to_seq(test) Confm = test_file(test_seq_mat,test_y,train,val) plot_fig(Confm)
在獲得預(yù)測(cè)結(jié)果最高的一組參數(shù)的選取后,我們訓(xùn)練整個(gè)train_set文件,訓(xùn)練過程如下,訓(xùn)練之前需刪除已有的訓(xùn)練文件(.h5),此函數(shù)中的test行可隨意選取,只是為了滿足test_file()函數(shù)的變量足夠。此函數(shù)只是用于訓(xùn)練出學(xué)習(xí)效果最好的數(shù)據(jù)并儲(chǔ)存。
def train(): train_df = pd.read_csv("./data/train_set.csv",sep='\t',nrows=200000) train = train_df.iloc[0:170000,:] test = train_df.iloc[0:10000,:] val = train_df.iloc[170000:,:] test_seq_mat,test_y = to_seq(test) Confm = test_file(test_seq_mat,test_y,train,val) plot_fig(Confm)
在獲得最優(yōu)的訓(xùn)練數(shù)據(jù)后,我們就可以開始預(yù)測(cè)了。我們將競(jìng)賽中提供的測(cè)試集帶入模型中,加載儲(chǔ)存好的訓(xùn)練集進(jìn)行預(yù)測(cè),得到預(yù)測(cè)矩陣。再將預(yù)測(cè)矩陣中每一行的最大值轉(zhuǎn)化為對(duì)應(yīng)的標(biāo)簽,儲(chǔ)存在輸出列表中即可,最后將該列表寫入'test_DL.csv'文件中上傳即可。(如此生成的csv文件同上一個(gè)模型一樣,需手動(dòng)打開刪除掉第一列)
def pred_file(): test_df = pd.read_csv('./data/test_a.csv') test_seq_mat = sequence.pad_sequences([i.split(' ') for i in tqdm(test_df['text'])],maxlen=max_len) inputs = Input(name='inputs',shape=[max_len]) ## Embedding(詞匯表大小,batch大小,每個(gè)新聞的詞長) layer = Embedding(max_words+1,128,input_length=max_len)(inputs) layer = LSTM(128)(layer) layer = Dense(128,activation="relu",name="FC1")(layer) layer = Dropout(0.5)(layer) layer = Dense(14,activation="softmax",name="FC2")(layer) model = Model(inputs=inputs,outputs=layer) model.summary() model.compile(loss="categorical_crossentropy",optimizer=rmsprop_v2.RMSprop(),metrics=["accuracy"]) model = load_model('my_model.h5') test_pre = model.predict(test_seq_mat) pred_result = [i.tolist().index(max(i.tolist())) for i in test_pre] res=pd.DataFrame() res['label']=pred_result res.to_csv('test_DL.csv')
整理后,我們只需要注釋掉對(duì)應(yīng)的指令行即可進(jìn)行訓(xùn)練或預(yù)測(cè)。
#如果想要訓(xùn)練,取消下行注釋,訓(xùn)練之前需先刪除原訓(xùn)練文件(.h5) #train() #如果想要查看模型效果,取消下行注釋(訓(xùn)練集:檢驗(yàn)集:測(cè)試集=15:2:3) # test_main() #如果想預(yù)測(cè)并生成csv文件,取消下行注釋 # pred_file()
3. 結(jié)果分析
最終獲得的混淆矩陣如下圖所示,14個(gè)標(biāo)簽預(yù)測(cè)的正確率均達(dá)到了80%以上,有11個(gè)標(biāo)簽在90%以上,有6個(gè)標(biāo)簽在95%以上。
繪制出來的預(yù)測(cè)結(jié)果如下圖所示,可見預(yù)測(cè)效果相當(dāng)理想,每個(gè)標(biāo)簽的正確率都尤為可觀,預(yù)測(cè)錯(cuò)誤的文本數(shù)相比于總量非常少。
最終上傳網(wǎng)站得到結(jié)果,F(xiàn)1值達(dá)90%以上,效果較好。
五、小結(jié)
本實(shí)驗(yàn)采用了傳統(tǒng)機(jī)器學(xué)習(xí)和基于LSTM的深度學(xué)習(xí)兩種方法對(duì)新聞文本進(jìn)行了分類,在兩種方法的對(duì)比下,深度學(xué)習(xí)的效果明顯優(yōu)于傳統(tǒng)的機(jī)器學(xué)習(xí),并在競(jìng)賽中取得了較好的成績(jī)(排名551)。但LSTM仍存在問題,一方面是RNN的梯度問題在LSTM及其變種里面得到了一定程度的解決,但還是不夠;另一方面,LSTM計(jì)算費(fèi)時(shí),每一個(gè)LSTM的cell里面都意味著有4個(gè)全連接層(MLP),如果LSTM的時(shí)間跨度很大,并且網(wǎng)絡(luò)又很深,這個(gè)計(jì)算量會(huì)很大,很耗時(shí)。
探尋更好的文本分類方法一直以來都是NLP在探索的方向,希望今后可以學(xué)習(xí)更多的分類方法,更多的機(jī)器學(xué)習(xí)和深度學(xué)習(xí)模型,提高分類效率。
到此這篇關(guān)于Python通過樸素貝葉斯和LSTM分別實(shí)現(xiàn)新聞文本分類的文章就介紹到這了,更多相關(guān)Python文本分類內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- Python機(jī)器學(xué)習(xí)應(yīng)用之樸素貝葉斯篇
- python機(jī)器學(xué)習(xí)樸素貝葉斯算法及模型的選擇和調(diào)優(yōu)詳解
- python實(shí)現(xiàn)貝葉斯推斷的例子
- python 實(shí)現(xiàn)樸素貝葉斯算法的示例
- Python實(shí)現(xiàn)樸素貝葉斯的學(xué)習(xí)與分類過程解析
- python實(shí)現(xiàn)基于樸素貝葉斯的垃圾分類算法
- python實(shí)現(xiàn)樸素貝葉斯算法
- 樸素貝葉斯Python實(shí)例及解析
- Python Multinomial Naive Bayes多項(xiàng)貝葉斯模型實(shí)現(xiàn)原理介紹
相關(guān)文章
python神經(jīng)網(wǎng)絡(luò)Keras構(gòu)建CNN網(wǎng)絡(luò)訓(xùn)練
這篇文章主要為大家介紹了python神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)使用Keras構(gòu)建CNN網(wǎng)絡(luò)訓(xùn)練,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05使用Python實(shí)現(xiàn)windows下的抓包與解析
這篇文章主要介紹了使用Python實(shí)現(xiàn)windows下的抓包與解析,非常不錯(cuò),具有參考借鑒價(jià)值,需要的朋友可以參考下2018-01-01基于Python的接口自動(dòng)化unittest測(cè)試框架和ddt數(shù)據(jù)驅(qū)動(dòng)詳解
這篇文章主要介紹了基于Python的接口自動(dòng)化unittest測(cè)試框架和ddt數(shù)據(jù)驅(qū)動(dòng)詳解,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-01-01python 借助numpy保存數(shù)據(jù)為csv格式的實(shí)現(xiàn)方法
今天小編就為大家分享一篇python 借助numpy保存數(shù)據(jù)為csv格式的實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-07-07matplotlib一維散點(diǎn)分布圖的實(shí)現(xiàn)
本文主要介紹了matplotlib一維散點(diǎn)分布圖的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03在linux系統(tǒng)中安裝python3.8.1?并卸載?python3.6.2?更新python3引導(dǎo)到3.8.1的
這篇文章主要介紹了如何在linux系統(tǒng)中安裝python3.8.1?并卸載?python3.6.2?更新python3引導(dǎo)到3.8.1,本文分步驟給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2023-11-11