pytorch下使用LSTM神經(jīng)網(wǎng)絡(luò)寫詩實例
在pytorch下,以數(shù)萬首唐詩為素材,訓(xùn)練雙層LSTM神經(jīng)網(wǎng)絡(luò),使其能夠以唐詩的方式寫詩。
代碼結(jié)構(gòu)分為四部分,分別為
1.model.py,定義了雙層LSTM模型
2.data.py,定義了從網(wǎng)上得到的唐詩數(shù)據(jù)的處理方法
3.utlis.py 定義了損失可視化的函數(shù)
4.main.py定義了模型參數(shù),以及訓(xùn)練、唐詩生成函數(shù)。
參考:電子工業(yè)出版社的《深度學(xué)習(xí)框架PyTorch:入門與實踐》第九章
main代碼及注釋如下
import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
class Config(object):
data_path = 'data/'
pickle_path = 'tang.npz'
author = None
constrain = None
category = 'poet.tang' #or poet.song
lr = 1e-3
weight_decay = 1e-4
use_gpu = True
epoch = 20
batch_size = 128
maxlen = 125
plot_every = 20
#use_env = True #是否使用visodm
env = 'poety'
#visdom env
max_gen_len = 200
debug_file = '/tmp/debugp'
model_path = None
prefix_words = '細雨魚兒出,微風(fēng)燕子斜。'
#不是詩歌組成部分,是意境
start_words = '閑云潭影日悠悠'
#詩歌開始
acrostic = False
#是否藏頭
model_prefix = 'checkpoints/tang'
#模型保存路徑
opt = Config()
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
'''
給定幾個詞,根據(jù)這幾個詞接著生成一首完整的詩歌
'''
results = list(start_words)
start_word_len = len(start_words)
# 手動設(shè)置第一個詞為<START>
# 這個地方有問題,最后需要再看一下
input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
if opt.use_gpu:input=input.cuda()
hidden = None
if prefix_words:
for word in prefix_words:
output,hidden = model(input,hidden)
# 下邊這句話是為了把input變成1*1?
input = Variable(input.data.new([word2ix[word]])).view(1,1)
for i in range(opt.max_gen_len):
output,hidden = model(input,hidden)
if i<start_word_len:
w = results[i]
input = Variable(input.data.new([word2ix[w]])).view(1,1)
else:
top_index = output.data[0].topk(1)[1][0]
w = ix2word[top_index]
results.append(w)
input = Variable(input.data.new([top_index])).view(1,1)
if w=='<EOP>':
del results[-1] #-1的意思是倒數(shù)第一個
break
return results
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
'''
生成藏頭詩
start_words : u'深度學(xué)習(xí)'
生成:
深木通中岳,青苔半日脂。
度山分地險,逆浪到南巴。
學(xué)道兵猶毒,當(dāng)時燕不移。
習(xí)根通古岸,開鏡出清羸。
'''
results = []
start_word_len = len(start_words)
input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
if opt.use_gpu:input=input.cuda()
hidden = None
index=0 # 用來指示已經(jīng)生成了多少句藏頭詩
# 上一個詞
pre_word='<START>'
if prefix_words:
for word in prefix_words:
output,hidden = model(input,hidden)
input = Variable(input.data.new([word2ix[word]])).view(1,1)
for i in range(opt.max_gen_len):
output,hidden = model(input,hidden)
top_index = output.data[0].topk(1)[1][0]
w = ix2word[top_index]
if (pre_word in {u'。',u'!','<START>'} ):
# 如果遇到句號,藏頭的詞送進去生成
if index==start_word_len:
# 如果生成的詩歌已經(jīng)包含全部藏頭的詞,則結(jié)束
break
else:
# 把藏頭的詞作為輸入送入模型
w = start_words[index]
index+=1
input = Variable(input.data.new([word2ix[w]])).view(1,1)
else:
# 否則的話,把上一次預(yù)測是詞作為下一個詞輸入
input = Variable(input.data.new([word2ix[w]])).view(1,1)
results.append(w)
pre_word = w
return results
def train(**kwargs):
for k,v in kwargs.items():
setattr(opt,k,v) #設(shè)置apt里屬性的值
vis = Visualizer(env=opt.env)
#獲取數(shù)據(jù)
data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函數(shù)
data = t.from_numpy(data)
#這個地方出錯了,是大寫的L
dataloader = t.utils.data.DataLoader(data,
batch_size = opt.batch_size,
shuffle = True,
num_workers = 1) #在python里,這樣寫程序可以嗎?
#模型定義
model = PoetryModel(len(word2ix), 128, 256)
optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
criterion = nn.CrossEntropyLoss()
if opt.model_path:
model.load_state_dict(t.load(opt.model_path))
if opt.use_gpu:
model.cuda()
criterion.cuda()
#The tnt.AverageValueMeter measures and returns the average value
#and the standard deviation of any collection of numbers that are
#added to it. It is useful, for instance, to measure the average
#loss over a collection of examples.
#The add() function expects as input a Lua number value, which
#is the value that needs to be added to the list of values to
#average. It also takes as input an optional parameter n that
#assigns a weight to value in the average, in order to facilitate
#computing weighted averages (default = 1).
#The tnt.AverageValueMeter has no parameters to be set at initialization time.
loss_meter = meter.AverageValueMeter()
for epoch in range(opt.epoch):
loss_meter.reset()
for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
#tqdm是python中的進度條
#訓(xùn)練
data_ = data_.long().transpose(1,0).contiguous()
#上邊一句話,把data_變成long類型,把1維和0維轉(zhuǎn)置,把內(nèi)存調(diào)成連續(xù)的
if opt.use_gpu: data_ = data_.cuda()
optimizer.zero_grad()
input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
#上邊一句,將輸入的詩句錯開一個字,形成訓(xùn)練和目標(biāo)
output,_ = model(input_)
loss = criterion(output, target.view(-1))
loss.backward()
optimizer.step()
loss_meter.add(loss.data[0]) #為什么是data[0]?
#可視化用到的是utlis.py里的函數(shù)
if (1+ii)%opt.plot_every ==0:
if os.path.exists(opt.debug_file):
ipdb.set_trace()
vis.plot('loss',loss_meter.value()[0])
# 下面是對目前模型情況的測試,詩歌原文
poetrys = [[ix2word[_word] for _word in data_[:,_iii]]
for _iii in range(data_.size(1))][:16]
#上面句子嵌套了兩個循環(huán),主要是將詩歌索引的前十六個字變成原文
vis.text('</br>'.join([''.join(poetry) for poetry in
poetrys]),win = u'origin_poem')
gen_poetries = []
#分別以以下幾個字作為詩歌的第一個字,生成8首詩
for word in list(u'春江花月夜涼如水'):
gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
gen_poetries.append(gen_poetry)
vis.text('</br>'.join([''.join(poetry) for poetry in
gen_poetries]), win = u'gen_poem')
t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))
def gen(**kwargs):
'''
提供命令行接口,用以生成相應(yīng)的詩
'''
for k,v in kwargs.items():
setattr(opt,k,v)
data, word2ix, ix2word = get_data(opt)
model = PoetryModel(len(word2ix), 128, 256)
map_location = lambda s,l:s
# 上邊句子里的map_location是在load里用的,用以加載到指定的CPU或GPU,
# 上邊句子的意思是將模型加載到默認(rèn)的GPU上
state_dict = t.load(opt.model_path, map_location = map_location)
model.load_state_dict(state_dict)
if opt.use_gpu:
model.cuda()
if sys.version_info.major == 3:
if opt.start_words.insprintable():
start_words = opt.start_words
prefix_words = opt.prefix_words if opt.prefix_words else None
else:
start_words = opt.start_words.encode('ascii',\
'surrogateescape').decode('utf8')
prefix_words = opt.prefix_words.encode('ascii',\
'surrogateescape').decode('utf8') if opt.prefix_words else None
start_words = start_words.replace(',',u',')\
.replace('.',u'。')\
.replace('?',u'?')
gen_poetry = gen_acrostic if opt.acrostic else generate
result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
print(''.join(result))
if __name__ == '__main__':
import fire
fire.Fire()
以上代碼給我一些經(jīng)驗,
1. 了解python的編程方式,如空格、換行等;進一步了解python的各個基本模塊;
2. 可能出的錯誤:函數(shù)名寫錯,大小寫,變量名寫錯,括號不全。
3. 對cuda()的用法有了進一步認(rèn)識;
4. 學(xué)會了調(diào)試程序(fire);
5. 學(xué)會了訓(xùn)練結(jié)果的可視化(visdom);
6. 進一步的了解了LSTM,對深度學(xué)習(xí)的架構(gòu)、實現(xiàn)有了宏觀把控。
這篇pytorch下使用LSTM神經(jīng)網(wǎng)絡(luò)寫詩實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python錯誤: SyntaxError: Non-ASCII character解決辦法
這篇文章主要介紹了Python錯誤: SyntaxError: Non-ASCII character解決辦法的相關(guān)資料,需要的朋友可以參考下2017-06-06
Python 中如何使用 setLevel() 設(shè)置日志級別
這篇文章主要介紹了在 Python 中使用setLevel() 設(shè)置日志級別,Python 提供了一個單獨的日志記錄模塊作為其標(biāo)準(zhǔn)庫的一部分,以簡化日志記錄,本文將討論日志記錄 setLevel 及其在 Python 中的工作方式,需要的朋友可以參考下2023-07-07
Python基于tkinter canvas實現(xiàn)圖片裁剪功能
這篇文章主要介紹了Python基于tkinter canvas實現(xiàn)圖片裁剪功能,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-11-11
詳解如何使用Python的Plotly庫進行交互式圖形可視化
Python中有許多強大的工具和庫可用于創(chuàng)建交互式圖形,其中之一就是Plotly庫,Plotly庫提供了豐富的功能和靈活的接口,使得創(chuàng)建各種類型的交互式圖形變得簡單而直觀,本文將介紹如何使用Plotly庫來創(chuàng)建交互式圖形,需要的朋友可以參考下2024-05-05
Python判斷字符串是否為字母或者數(shù)字(浮點數(shù))的多種方法
本文給大家?guī)砣N方法基于Python判斷字符串是否為字母或者數(shù)字(浮點數(shù)),非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下2018-08-08

