pytorch下使用LSTM神經(jīng)網(wǎng)絡(luò)寫(xiě)詩(shī)實(shí)例
在pytorch下,以數(shù)萬(wàn)首唐詩(shī)為素材,訓(xùn)練雙層LSTM神經(jīng)網(wǎng)絡(luò),使其能夠以唐詩(shī)的方式寫(xiě)詩(shī)。
代碼結(jié)構(gòu)分為四部分,分別為
1.model.py,定義了雙層LSTM模型
2.data.py,定義了從網(wǎng)上得到的唐詩(shī)數(shù)據(jù)的處理方法
3.utlis.py 定義了損失可視化的函數(shù)
4.main.py定義了模型參數(shù),以及訓(xùn)練、唐詩(shī)生成函數(shù)。
參考:電子工業(yè)出版社的《深度學(xué)習(xí)框架PyTorch:入門(mén)與實(shí)踐》第九章
main代碼及注釋如下
import sys, os import torch as t from data import get_data from model import PoetryModel from torch import nn from torch.autograd import Variable from utils import Visualizer import tqdm from torchnet import meter import ipdb class Config(object): data_path = 'data/' pickle_path = 'tang.npz' author = None constrain = None category = 'poet.tang' #or poet.song lr = 1e-3 weight_decay = 1e-4 use_gpu = True epoch = 20 batch_size = 128 maxlen = 125 plot_every = 20 #use_env = True #是否使用visodm env = 'poety' #visdom env max_gen_len = 200 debug_file = '/tmp/debugp' model_path = None prefix_words = '細(xì)雨魚(yú)兒出,微風(fēng)燕子斜。' #不是詩(shī)歌組成部分,是意境 start_words = '閑云潭影日悠悠' #詩(shī)歌開(kāi)始 acrostic = False #是否藏頭 model_prefix = 'checkpoints/tang' #模型保存路徑 opt = Config() def generate(model, start_words, ix2word, word2ix, prefix_words=None): ''' 給定幾個(gè)詞,根據(jù)這幾個(gè)詞接著生成一首完整的詩(shī)歌 ''' results = list(start_words) start_word_len = len(start_words) # 手動(dòng)設(shè)置第一個(gè)詞為<START> # 這個(gè)地方有問(wèn)題,最后需要再看一下 input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long()) if opt.use_gpu:input=input.cuda() hidden = None if prefix_words: for word in prefix_words: output,hidden = model(input,hidden) # 下邊這句話是為了把input變成1*1? input = Variable(input.data.new([word2ix[word]])).view(1,1) for i in range(opt.max_gen_len): output,hidden = model(input,hidden) if i<start_word_len: w = results[i] input = Variable(input.data.new([word2ix[w]])).view(1,1) else: top_index = output.data[0].topk(1)[1][0] w = ix2word[top_index] results.append(w) input = Variable(input.data.new([top_index])).view(1,1) if w=='<EOP>': del results[-1] #-1的意思是倒數(shù)第一個(gè) break return results def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None): ''' 生成藏頭詩(shī) start_words : u'深度學(xué)習(xí)' 生成: 深木通中岳,青苔半日脂。 度山分地險(xiǎn),逆浪到南巴。 學(xué)道兵猶毒,當(dāng)時(shí)燕不移。 習(xí)根通古岸,開(kāi)鏡出清羸。 ''' results = [] start_word_len = len(start_words) input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long()) if opt.use_gpu:input=input.cuda() hidden = None index=0 # 用來(lái)指示已經(jīng)生成了多少句藏頭詩(shī) # 上一個(gè)詞 pre_word='<START>' if prefix_words: for word in prefix_words: output,hidden = model(input,hidden) input = Variable(input.data.new([word2ix[word]])).view(1,1) for i in range(opt.max_gen_len): output,hidden = model(input,hidden) top_index = output.data[0].topk(1)[1][0] w = ix2word[top_index] if (pre_word in {u'。',u'!','<START>'} ): # 如果遇到句號(hào),藏頭的詞送進(jìn)去生成 if index==start_word_len: # 如果生成的詩(shī)歌已經(jīng)包含全部藏頭的詞,則結(jié)束 break else: # 把藏頭的詞作為輸入送入模型 w = start_words[index] index+=1 input = Variable(input.data.new([word2ix[w]])).view(1,1) else: # 否則的話,把上一次預(yù)測(cè)是詞作為下一個(gè)詞輸入 input = Variable(input.data.new([word2ix[w]])).view(1,1) results.append(w) pre_word = w return results def train(**kwargs): for k,v in kwargs.items(): setattr(opt,k,v) #設(shè)置apt里屬性的值 vis = Visualizer(env=opt.env) #獲取數(shù)據(jù) data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函數(shù) data = t.from_numpy(data) #這個(gè)地方出錯(cuò)了,是大寫(xiě)的L dataloader = t.utils.data.DataLoader(data, batch_size = opt.batch_size, shuffle = True, num_workers = 1) #在python里,這樣寫(xiě)程序可以嗎? #模型定義 model = PoetryModel(len(word2ix), 128, 256) optimizer = t.optim.Adam(model.parameters(), lr=opt.lr) criterion = nn.CrossEntropyLoss() if opt.model_path: model.load_state_dict(t.load(opt.model_path)) if opt.use_gpu: model.cuda() criterion.cuda() #The tnt.AverageValueMeter measures and returns the average value #and the standard deviation of any collection of numbers that are #added to it. It is useful, for instance, to measure the average #loss over a collection of examples. #The add() function expects as input a Lua number value, which #is the value that needs to be added to the list of values to #average. It also takes as input an optional parameter n that #assigns a weight to value in the average, in order to facilitate #computing weighted averages (default = 1). #The tnt.AverageValueMeter has no parameters to be set at initialization time. loss_meter = meter.AverageValueMeter() for epoch in range(opt.epoch): loss_meter.reset() for ii,data_ in tqdm.tqdm(enumerate(dataloader)): #tqdm是python中的進(jìn)度條 #訓(xùn)練 data_ = data_.long().transpose(1,0).contiguous() #上邊一句話,把data_變成long類(lèi)型,把1維和0維轉(zhuǎn)置,把內(nèi)存調(diào)成連續(xù)的 if opt.use_gpu: data_ = data_.cuda() optimizer.zero_grad() input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:]) #上邊一句,將輸入的詩(shī)句錯(cuò)開(kāi)一個(gè)字,形成訓(xùn)練和目標(biāo) output,_ = model(input_) loss = criterion(output, target.view(-1)) loss.backward() optimizer.step() loss_meter.add(loss.data[0]) #為什么是data[0]? #可視化用到的是utlis.py里的函數(shù) if (1+ii)%opt.plot_every ==0: if os.path.exists(opt.debug_file): ipdb.set_trace() vis.plot('loss',loss_meter.value()[0]) # 下面是對(duì)目前模型情況的測(cè)試,詩(shī)歌原文 poetrys = [[ix2word[_word] for _word in data_[:,_iii]] for _iii in range(data_.size(1))][:16] #上面句子嵌套了兩個(gè)循環(huán),主要是將詩(shī)歌索引的前十六個(gè)字變成原文 vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]),win = u'origin_poem') gen_poetries = [] #分別以以下幾個(gè)字作為詩(shī)歌的第一個(gè)字,生成8首詩(shī) for word in list(u'春江花月夜涼如水'): gen_poetry = ''.join(generate(model,word,ix2word,word2ix)) gen_poetries.append(gen_poetry) vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]), win = u'gen_poem') t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch)) def gen(**kwargs): ''' 提供命令行接口,用以生成相應(yīng)的詩(shī) ''' for k,v in kwargs.items(): setattr(opt,k,v) data, word2ix, ix2word = get_data(opt) model = PoetryModel(len(word2ix), 128, 256) map_location = lambda s,l:s # 上邊句子里的map_location是在load里用的,用以加載到指定的CPU或GPU, # 上邊句子的意思是將模型加載到默認(rèn)的GPU上 state_dict = t.load(opt.model_path, map_location = map_location) model.load_state_dict(state_dict) if opt.use_gpu: model.cuda() if sys.version_info.major == 3: if opt.start_words.insprintable(): start_words = opt.start_words prefix_words = opt.prefix_words if opt.prefix_words else None else: start_words = opt.start_words.encode('ascii',\ 'surrogateescape').decode('utf8') prefix_words = opt.prefix_words.encode('ascii',\ 'surrogateescape').decode('utf8') if opt.prefix_words else None start_words = start_words.replace(',',u',')\ .replace('.',u'。')\ .replace('?',u'?') gen_poetry = gen_acrostic if opt.acrostic else generate result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words) print(''.join(result)) if __name__ == '__main__': import fire fire.Fire()
以上代碼給我一些經(jīng)驗(yàn),
1. 了解python的編程方式,如空格、換行等;進(jìn)一步了解python的各個(gè)基本模塊;
2. 可能出的錯(cuò)誤:函數(shù)名寫(xiě)錯(cuò),大小寫(xiě),變量名寫(xiě)錯(cuò),括號(hào)不全。
3. 對(duì)cuda()的用法有了進(jìn)一步認(rèn)識(shí);
4. 學(xué)會(huì)了調(diào)試程序(fire);
5. 學(xué)會(huì)了訓(xùn)練結(jié)果的可視化(visdom);
6. 進(jìn)一步的了解了LSTM,對(duì)深度學(xué)習(xí)的架構(gòu)、實(shí)現(xiàn)有了宏觀把控。
這篇pytorch下使用LSTM神經(jīng)網(wǎng)絡(luò)寫(xiě)詩(shī)實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
基于PyQt4和PySide實(shí)現(xiàn)輸入對(duì)話框效果
這篇文章主要為大家詳細(xì)介紹了基于PyQt4和PySide實(shí)現(xiàn)輸入對(duì)話框效果,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-02-02Python用模塊pytz來(lái)轉(zhuǎn)換時(shí)區(qū)
在Python中,與時(shí)間相關(guān)的庫(kù)有好些,可以幫助我們快速的處理與時(shí)間相關(guān)的需求和問(wèn)題。這里想和大家分享一下如何在Python用模塊pytz來(lái)轉(zhuǎn)換時(shí)區(qū)。2016-08-08Python使用for生成列表實(shí)現(xiàn)過(guò)程解析
這篇文章主要介紹了Python使用for生成列表實(shí)現(xiàn)過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-09-09Python錯(cuò)誤: SyntaxError: Non-ASCII character解決辦法
這篇文章主要介紹了Python錯(cuò)誤: SyntaxError: Non-ASCII character解決辦法的相關(guān)資料,需要的朋友可以參考下2017-06-06Python 中如何使用 setLevel() 設(shè)置日志級(jí)別
這篇文章主要介紹了在 Python 中使用setLevel() 設(shè)置日志級(jí)別,Python 提供了一個(gè)單獨(dú)的日志記錄模塊作為其標(biāo)準(zhǔn)庫(kù)的一部分,以簡(jiǎn)化日志記錄,本文將討論日志記錄 setLevel 及其在 Python 中的工作方式,需要的朋友可以參考下2023-07-07Python基于tkinter canvas實(shí)現(xiàn)圖片裁剪功能
這篇文章主要介紹了Python基于tkinter canvas實(shí)現(xiàn)圖片裁剪功能,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-11-11詳解如何使用Python的Plotly庫(kù)進(jìn)行交互式圖形可視化
Python中有許多強(qiáng)大的工具和庫(kù)可用于創(chuàng)建交互式圖形,其中之一就是Plotly庫(kù),Plotly庫(kù)提供了豐富的功能和靈活的接口,使得創(chuàng)建各種類(lèi)型的交互式圖形變得簡(jiǎn)單而直觀,本文將介紹如何使用Plotly庫(kù)來(lái)創(chuàng)建交互式圖形,需要的朋友可以參考下2024-05-05Python 類(lèi)中引用其他類(lèi)的實(shí)現(xiàn)示例
在Python中,類(lèi)的引用是通過(guò)屬性或方法與其他類(lèi)實(shí)例關(guān)聯(lián),實(shí)現(xiàn)復(fù)雜邏輯,本文介紹了關(guān)聯(lián)、組合等類(lèi)之間的引用方式,具有一定的參考價(jià)值,感興趣的可以了解一下2024-09-09Python判斷字符串是否為字母或者數(shù)字(浮點(diǎn)數(shù))的多種方法
本文給大家?guī)?lái)三種方法基于Python判斷字符串是否為字母或者數(shù)字(浮點(diǎn)數(shù)),非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2018-08-08