pytorch實現用CNN和LSTM對文本進行分類方式
更新時間:2020年01月08日 09:28:17 作者:Alphapeople
今天小編就為大家分享一篇pytorch實現用CNN和LSTM對文本進行分類方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
model.py:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
class TextRNN(nn.Module):
    """Text classifier: embedding -> 2-layer bidirectional GRU -> MLP head.

    ``forward`` takes a tensor of integer token ids (ids must be < 5000, the
    embedding table size) laid out as ``(batch, seq_len)`` and returns a
    ``(batch, 10)`` tensor of class probabilities (softmax over dim 1).
    """

    def __init__(self):
        super(TextRNN, self).__init__()
        # Token embedding: 5000-entry vocabulary, 64-dimensional vectors.
        self.embedding = nn.Embedding(5000, 64)
        # LSTM alternative with the same interface:
        # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
        #                    bidirectional=True, batch_first=True)
        # batch_first=True is required: forward() feeds (batch, seq, feature)
        # tensors and reads the last time step with x[:, -1, :].  With the
        # default layout the GRU would run its recurrence across the batch
        # axis and mix unrelated samples together.
        self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2,
                          bidirectional=True, batch_first=True)
        # 256 = 2 directions * 128 hidden units.
        # NOTE(review): p=0.8 drops 80% of activations, which is unusually
        # aggressive; kept as in the original.
        self.f1 = nn.Sequential(nn.Linear(256, 128),
                                nn.Dropout(0.8),
                                nn.ReLU())
        # Explicit dim: the implicit-dim form of Softmax is deprecated.
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities of shape ``(batch, 10)`` for ids ``x``."""
        x = self.embedding(x)            # (batch, seq_len, 64)
        x, _ = self.rnn(x)               # (batch, seq_len, 256)
        # Tie dropout to the module mode so eval()/inference is deterministic;
        # F.dropout defaults to training=True and would otherwise always fire.
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.f1(x[:, -1, :])         # features at the last time step
        return self.f2(x)
class TextCNN(nn.Module):
    """Text classifier: embedding -> single Conv1d -> two-layer MLP head.

    The defaults reproduce the original fixed architecture (5000-token
    vocabulary, 64-d embeddings, 256 filters of width 5, inputs of exactly
    600 tokens, 10 classes), so layer names and shapes are unchanged.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width (input channels of the convolution).
        num_filters: output channels of the convolution.
        kernel_size: convolution window, in tokens.
        seq_len: fixed input length the classifier head is sized for.
        num_classes: size of the output distribution.
    """

    def __init__(self, vocab_size=5000, embed_dim=64, num_filters=256,
                 kernel_size=5, seq_len=600, num_classes=10):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, num_filters, kernel_size)
        # A stride-1, unpadded convolution leaves seq_len - kernel_size + 1
        # positions; with the defaults this is 256 * 596.
        self.flat_dim = num_filters * (seq_len - kernel_size + 1)
        self.f1 = nn.Sequential(nn.Linear(self.flat_dim, 128),
                                nn.ReLU())
        # Explicit dim: the implicit-dim form of Softmax is deprecated.
        self.f2 = nn.Sequential(nn.Linear(128, num_classes),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities ``(batch, num_classes)`` for ids ``x``.

        ``x`` is an integer tensor of shape ``(batch, seq_len)``.
        """
        x = self.embedding(x)        # (batch, seq_len, embed_dim)
        # Conv1d wants channels first.  Transpose on the tensor itself: the
        # previous detach()/numpy round-trip cut the embedding out of the
        # autograd graph, so its weights never received gradients.
        x = x.transpose(1, 2)        # (batch, embed_dim, seq_len)
        x = self.conv(x)             # (batch, num_filters, conv_len)
        x = x.reshape(x.size(0), -1)
        x = self.f1(x)
        return self.f2(x)
train.py:
# coding: utf-8
from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import os
import numpy as np
from model import TextRNN,TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
# Dataset root directory and the per-split file paths derived from it.
base_dir = 'cnews'
train_dir, test_dir, val_dir, vocab_dir = (
    os.path.join(base_dir, 'cnews.%s.txt' % split)
    for split in ('train', 'test', 'val', 'vocab')
)
def _to_tensors(x_batch, y_batch):
    """Convert one (ids, one-hot labels) batch to (LongTensor, FloatTensor)."""
    return (torch.LongTensor(np.array(x_batch)),
            torch.Tensor(np.array(y_batch)))


def _evaluate(model, x_data, y_data, batch_size=100):
    """Return the accuracy of ``model`` over the whole of (x_data, y_data).

    Runs in eval mode with gradients disabled and performs no parameter
    updates; labels are one-hot, so argmax recovers the class id.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x_batch, y_batch in batch_iter(x_data, y_data, batch_size):
            x, y = _to_tensors(x_batch, y_batch)
            out = model(x)
            hits = torch.argmax(out, 1) == torch.argmax(y, 1)
            correct += int(hits.sum().item())
            total += int(y.size(0))
    return correct / float(total) if total else 0.0


def train():
    """Train the classifier and checkpoint the best validation accuracy.

    Reads the module-level ``word_to_id`` and ``cat_to_id`` tables, which are
    only populated by the ``__main__`` section, so this is meant to be run as
    a script.  Every 20th epoch the model is scored on the validation split
    and, when the score improves, its ``state_dict`` is written to
    ``model_params.pkl``.
    """
    # Token ids (padded to 600) and one-hot labels for each split.
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, 600)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, 600)

    # Pick the recurrent or the convolutional model.
    model = TextRNN()
    # model = TextCNN()

    # Loss function.
    # NOTE(review): MultiLabelSoftMarginLoss applies a sigmoid internally
    # while the models already emit softmax probabilities — confirm that this
    # pairing is intended.
    criterion = nn.MultiLabelSoftMarginLoss()
    # criterion = nn.BCELoss()
    # criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    best_val_acc = 0
    for epoch in range(1000):
        # Evaluation below switches to eval mode; switch back every epoch.
        model.train()
        for x_batch, y_batch in batch_iter(x_train, y_train, 100):
            x, y = _to_tensors(x_batch, y_batch)
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validate on the held-out split.  The previous loop iterated the
        # already-exhausted training generator (so it never ran), would have
        # taken optimizer steps on it, and selected checkpoints using the
        # accuracy of the last training batch.
        if (epoch + 1) % 20 == 0:
            accracy = _evaluate(model, x_val, y_val, 100)
            if accracy > best_val_acc:
                torch.save(model.state_dict(), 'model_params.pkl')
                best_val_acc = accracy
            print(accracy)
if __name__ == '__main__':
    # Every character in the vocabulary file, plus its character -> id table.
    words, word_to_id = read_vocab(vocab_dir)
    # Category names, plus their name -> id table.
    categories, cat_to_id = read_category()
    # Number of entries in the vocabulary.
    vocab_size = len(words)
    # train() picks word_to_id / cat_to_id up from the module globals above.
    train()
test.py:
# coding: utf-8
from __future__ import print_function
import os
import tensorflow.contrib.keras as kr
import torch
from torch import nn
from cnews_loader import read_category, read_vocab
from model import TextRNN
from torch.autograd import Variable
import numpy as np
# Python 3 has no ``unicode`` builtin: looking the name up raises NameError,
# in which case it is aliased to ``str`` so predict() can call it either way.
try:
    unicode = unicode
except NameError:
    unicode = str

# Dataset root and the vocabulary file inside it.
base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
class TextCNN(nn.Module):
    """Text classifier: embedding -> single Conv1d -> two-layer MLP head.

    Layer names and default shapes are unchanged (152576 == 256 * 596), so a
    ``state_dict`` saved from a model with the same layout still loads.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width (input channels of the convolution).
        num_filters: output channels of the convolution.
        kernel_size: convolution window, in tokens.
        seq_len: fixed input length the classifier head is sized for.
        num_classes: size of the output distribution.
    """

    def __init__(self, vocab_size=5000, embed_dim=64, num_filters=256,
                 kernel_size=5, seq_len=600, num_classes=10):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, num_filters, kernel_size)
        # A stride-1, unpadded convolution leaves seq_len - kernel_size + 1
        # positions; with the defaults this is 256 * 596 = 152576.
        self.flat_dim = num_filters * (seq_len - kernel_size + 1)
        self.f1 = nn.Sequential(nn.Linear(self.flat_dim, 128),
                                nn.ReLU())
        # Explicit dim: the implicit-dim form of Softmax is deprecated.
        self.f2 = nn.Sequential(nn.Linear(128, num_classes),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities ``(batch, num_classes)`` for ids ``x``.

        ``x`` is an integer tensor of shape ``(batch, seq_len)``.
        """
        x = self.embedding(x)        # (batch, seq_len, embed_dim)
        # Conv1d wants channels first.  Transpose the tensor directly rather
        # than round-tripping through numpy, which detached the embedding
        # from autograd and only works for CPU tensors.
        x = x.transpose(1, 2)        # (batch, embed_dim, seq_len)
        x = self.conv(x)             # (batch, num_filters, conv_len)
        x = x.reshape(x.size(0), -1)
        x = self.f1(x)
        return self.f2(x)
class CnnModel:
    """Wrap a trained TextCNN so that raw text can be classified directly."""

    def __init__(self):
        """Load the category list, the vocabulary and the saved weights."""
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextCNN()
        # map_location lets a checkpoint saved on another device load on CPU.
        self.model.load_state_dict(
            torch.load('model_params.pkl', map_location='cpu'))
        # Inference only: put any dropout/batch-norm layers in eval mode.
        self.model.eval()

    def predict(self, message):
        """Return the category name predicted for the text ``message``.

        Characters missing from the vocabulary are dropped.  The id sequence
        is then left-truncated / left-padded with 0 to exactly 600 entries,
        the same layout ``pad_sequences`` produces with its defaults
        (padding='pre', truncating='pre', value=0).
        """
        # Works on both Python 2 and 3: fall back to ``str`` where the
        # ``unicode`` name does not exist.
        try:
            text_type = unicode
        except NameError:
            text_type = str
        content = text_type(message)
        ids = [self.word_to_id[ch] for ch in content if ch in self.word_to_id]
        ids = ids[-600:]                        # keep the last 600 tokens
        padded = [0] * (600 - len(ids)) + ids   # pad at the front
        data = torch.LongTensor([padded])
        # No gradients are needed for a forward-only prediction.
        with torch.no_grad():
            y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]
class RnnModel:
    """Wrap a trained TextRNN so that raw text can be classified directly."""

    def __init__(self):
        """Load the category list, the vocabulary and the saved weights."""
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextRNN()
        # NOTE(review): the training script saves to 'model_params.pkl';
        # confirm where 'model_rnn_params.pkl' is expected to come from.
        # map_location lets a checkpoint saved on another device load on CPU.
        self.model.load_state_dict(
            torch.load('model_rnn_params.pkl', map_location='cpu'))
        # Inference only: without eval() the model's dropout layers stay
        # active and predictions become random.
        self.model.eval()

    def predict(self, message):
        """Return the category name predicted for the text ``message``.

        Characters missing from the vocabulary are dropped.  The id sequence
        is then left-truncated / left-padded with 0 to exactly 600 entries,
        the same layout ``pad_sequences`` produces with its defaults
        (padding='pre', truncating='pre', value=0).
        """
        # Works on both Python 2 and 3: fall back to ``str`` where the
        # ``unicode`` name does not exist.
        try:
            text_type = unicode
        except NameError:
            text_type = str
        content = text_type(message)
        ids = [self.word_to_id[ch] for ch in content if ch in self.word_to_id]
        ids = ids[-600:]                        # keep the last 600 tokens
        padded = [0] * (600 - len(ids)) + ids   # pad at the front
        data = torch.LongTensor([padded])
        # No gradients are needed for a forward-only prediction.
        with torch.no_grad():
            y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]
if __name__ == '__main__':
    # Demo: classify two sample news snippets and print each with its label.
    # Swap in RnnModel() to use the recurrent classifier instead.
    model = CnnModel()
    # model = RnnModel()
    test_demo = [
        '湖人助教力助科比恢復(fù)手感 他也是阿泰的精神導(dǎo)師新浪體育訊記者戴高樂報(bào)道 上賽季,科比的右手食指遭遇重創(chuàng),他的投籃手感也因此大受影響。不過很快科比就調(diào)整了自己的投籃手型,并通過這一方式讓自己的投籃命中率回升。而在這科比背后,有一位特別助教對科比幫助很大,他就是查克·珀森。珀森上賽季擔(dān)任湖人的特別助教,除了幫助科比調(diào)整投籃手型之外,他的另一個(gè)重要任務(wù)就是擔(dān)任阿泰的精神導(dǎo)師。來到湖人隊(duì)之后,阿泰收斂起了暴躁的脾氣,成為湖人奪冠路上不可或缺的一員,珀森的“心靈按摩”功不可沒。經(jīng)歷了上賽季的成功之后,珀森本賽季被“升職”成為湖人隊(duì)的全職助教,每場比賽,他都會(huì)坐在球場邊,幫助禪師杰克遜一起指揮湖人球員在場上拼殺。對于珀森的工作,禪師非常欣賞,“查克非常善于分析問題,”菲爾·杰克遜說,“他總是在尋找問題的答案,同時(shí)也在找造成這一問題的原因,這是我們都非常樂于看到的。我會(huì)在平時(shí)把防守中出現(xiàn)的一些問題交給他,然后他會(huì)通過組織球員練習(xí)找到解決的辦法。他在球員時(shí)代曾是一名很好的外線投手,不過現(xiàn)在他與內(nèi)線球員的配合也相當(dāng)不錯(cuò)。',
        '弗老大被裁美國媒體看熱鬧“特權(quán)”在中國像蠢蛋弗老大要走了。雖然他只在首鋼男籃效力了13天,而且表現(xiàn)毫無亮點(diǎn),大大地讓球迷和俱樂部失望了,但就像中國人常說的“好聚好散”,隊(duì)友還是友好地與他告別,俱樂部與他和平分手,球迷還請他留下了在北京的最后一次簽名。相比之下,弗老大的同胞美國人卻沒那么“寬容”。他們嘲諷這位NBA前巨星的英雄遲暮,批評他在CBA的業(yè)余表現(xiàn),還驚訝于中國人的“大方”。今天,北京首鋼俱樂部將與弗朗西斯繼續(xù)商討解約一事。從昨日的進(jìn)展來看,雙方可以做到“買賣不成人意在”,但回到美國后,恐怕等待弗朗西斯的就沒有這么輕松的環(huán)境了。進(jìn)展@北京昨日與隊(duì)友告別 最后一次為球迷簽名弗朗西斯在13天里為首鋼隊(duì)打了4場比賽,3場的得分為0,只有一場得了2分。昨天是他來到北京的第14天,雖然他與首鋼還未正式解約,但雙方都明白“緣分已盡”。下午,弗朗西斯來到首鋼俱樂部與隊(duì)友們告別。弗朗西斯走到隊(duì)友身邊,依次與他們握手擁抱?!澳銈兌紝ξ液芎?,安排的條件也很好,我很喜歡這支球隊(duì),想融入你們,但我現(xiàn)在真的很不適應(yīng)。希望你們',
    ]
    for text in test_demo:
        print(text, ":", model.predict(text))
cnews_loader.py:
# coding: utf-8
import sys
from collections import Counter
import numpy as np
import tensorflow.contrib.keras as kr
# Record once whether we are on Python 3; the helpers below branch on it.
is_py3 = sys.version_info[0] > 2
if not is_py3:
    # Python 2 only: re-expose setdefaultencoding and make UTF-8 the default.
    reload(sys)
    sys.setdefaultencoding("utf-8")
def native_word(word, encoding='utf-8'):
    """Return ``word`` as the interpreter's native string type.

    On Python 3 the word is returned unchanged; on Python 2 it is encoded
    with ``encoding``.  Intended for using a model trained under Python 3
    from Python 2.
    """
    return word if is_py3 else word.encode(encoding)
def native_content(content):
    """Return ``content`` as text: unchanged on Python 3, UTF-8 decoded on 2."""
    return content if is_py3 else content.decode('utf-8')
def open_file(filename, mode='r'):
    """Open ``filename`` the same way under Python 2 and Python 3.

    mode: 'r' or 'w' for read or write.  On Python 3 the file is opened as
    UTF-8 text and encoding errors are ignored; on Python 2 the builtin
    ``open`` is used as-is.
    """
    if not is_py3:
        return open(filename, mode)
    return open(filename, mode, encoding='utf-8', errors='ignore')
def read_file(filename):
    """Read a ``label<TAB>content`` file into parallel lists.

    Returns:
        (contents, labels): ``contents[i]`` is the list of characters of the
        i-th usable line and ``labels[i]`` is its label.

    Lines that do not split into exactly two tab-separated fields, have an
    empty content field, or cannot be decoded are skipped, as before.  Any
    other error now propagates instead of being swallowed by a bare
    ``except``.
    """
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) != 2:
                continue  # malformed line: not exactly "label<TAB>content"
            label, content = fields
            if not content:
                continue
            try:
                # Convert both fields before appending either, so the two
                # lists can never get out of step.
                chars = list(native_content(content))
                label = native_content(label)
            except UnicodeError:
                continue  # undecodable line (only possible on Python 2)
            contents.append(chars)
            labels.append(label)
    return contents, labels
def build_vocab(train_dir, vocab_dir, vocab_size=5000):
    """Build the character vocabulary from the training set and save it.

    Writes one token per line to ``vocab_dir``: ``<PAD>`` first, then the
    ``vocab_size - 1`` most frequent characters of the training contents in
    descending frequency.

    Args:
        train_dir: path of the training file (``label<TAB>content`` lines).
        vocab_dir: path of the vocabulary file to write.
        vocab_size: total vocabulary size, including ``<PAD>``.
    """
    data_train, _ = read_file(train_dir)

    counter = Counter()
    for content in data_train:
        counter.update(content)

    # Build the list directly rather than unpacking zip(*pairs), which
    # raised on an empty training set or vocab_size == 1.
    frequent = [word for word, _ in counter.most_common(vocab_size - 1)]
    # <PAD> is prepended so every text can be padded to the same length
    # with id 0.
    words = ['<PAD>'] + frequent

    # Context manager guarantees the file is flushed and closed.
    with open_file(vocab_dir, mode='w') as f:
        f.write('\n'.join(words) + '\n')
def read_vocab(vocab_dir):
    """Load the vocabulary file.

    Returns:
        (words, word_to_id): the stripped tokens in file order, and a dict
        mapping each token to its line index.
    """
    with open_file(vocab_dir) as fp:
        # On Python 2 native_content turns every entry into unicode.
        words = [native_content(line.strip()) for line in fp]
    word_to_id = {word: index for index, word in enumerate(words)}
    return words, word_to_id
def read_category():
    """Return the fixed category list and its name -> id mapping.

    Ids follow the order of the list below.
    """
    names = ['體育', '財(cái)經(jīng)', '房產(chǎn)', '家居', '教育', '科技', '時(shí)尚', '時(shí)政', '游戲', '娛樂']
    categories = list(map(native_content, names))
    cat_to_id = {name: index for index, name in enumerate(categories)}
    return categories, cat_to_id
def to_words(content, words):
    """Turn a sequence of ids back into text by looking each up in ``words``."""
    pieces = [words[index] for index in content]
    return ''.join(pieces)
def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """Convert a labelled text file into padded id sequences and one-hot labels.

    Args:
        filename: data file readable by ``read_file``.
        word_to_id: character -> id table; unknown characters are dropped.
        cat_to_id: label -> id table.
        max_length: length every id sequence is padded/truncated to.

    Returns:
        (x_pad, y_pad): the padded id sequences and the one-hot labels.
    """
    # Every text as a character list, with its category name.
    contents, labels = read_file(filename)

    data_id = []
    label_id = []
    for position, chars in enumerate(contents):
        # Ids of the characters that are in the vocabulary.
        data_id.append([word_to_id[ch] for ch in chars if ch in word_to_id])
        # Id of the category of this text.
        label_id.append(cat_to_id[labels[position]])

    # Keras pad_sequences brings every sequence to the same fixed length.
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    # One-hot encode the labels.
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))
    return x_pad, y_pad
def batch_iter(x, y, batch_size=64):
    """Yield shuffled ``(x_batch, y_batch)`` mini-batches covering the data once.

    Args:
        x: array supporting fancy indexing and ``len`` (e.g. numpy array).
        y: array aligned with ``x`` along the first axis.
        batch_size: maximum number of samples per batch; the final batch may
            be smaller.

    Yields:
        Tuples of slices of ``x`` and ``y`` taken under one shared random
        permutation, so samples and labels stay aligned.  Nothing is yielded
        for empty input (previously a single empty batch was produced).

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    data_len = len(x)
    # One permutation applied to both arrays keeps each sample with its label.
    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    # Integer stepping replaces the float-division batch count; slicing
    # clamps the end of the last batch to data_len automatically.
    for start_id in range(0, data_len, batch_size):
        end_id = start_id + batch_size
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
以上這篇pytorch實現用CNN和LSTM對文本進行分類方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
如何向scrapy中的spider傳遞參數的幾種方法
這篇文章主要介紹了如何向scrapy中的spider傳遞參數的幾種方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-11-11
python+opencv實現的簡單人臉識別代碼示例
這篇文章主要介紹了圖像識別 python+opencv的簡單人臉識別,具有一定參考價值,需要的朋友可以參考下。2017-11-11
解決Python正則表達式匹配反斜杠'\'問題
這篇文章主要介紹了Python正則表達式匹配反斜杠'\'問題 ,很多朋友在使用python 正則式的過程中,經常被這個問題困擾,今天小編通過代碼給大家詳細介紹,需要的朋友可以參考下2019-07-07
Python實現判斷一個字符串是否包含子串的方法總結
這篇文章主要介紹了Python實現判斷一個字符串是否包含子串的方法,結合實例形式總結分析了四種比較常用的字符串子串判定方法,需要的朋友可以參考下2017-11-11
Python os.rename() 重命名目錄和文件的示例
今天小編就為大家分享一篇Python os.rename() 重命名目錄和文件的示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10

