Python深度強(qiáng)化學(xué)習(xí)之DQN算法原理詳解
DQN算法是DeepMind團(tuán)隊(duì)提出的一種深度強(qiáng)化學(xué)習(xí)算法,在許多電動(dòng)游戲中達(dá)到人類(lèi)玩家甚至超越人類(lèi)玩家的水準(zhǔn),本文就帶領(lǐng)大家了解一下這個(gè)算法,論文的鏈接見(jiàn)下方。
論文:Human-level control through deep reinforcement learning | Nature
代碼:后續(xù)會(huì)將代碼上傳到Github上...
1 DQN算法簡(jiǎn)介
Q-learning算法采用一個(gè)Q-tabel來(lái)記錄每個(gè)狀態(tài)下的動(dòng)作值,當(dāng)狀態(tài)空間或動(dòng)作空間較大時(shí),需要的存儲(chǔ)空間也會(huì)較大。如果狀態(tài)空間或動(dòng)作空間連續(xù),則該算法無(wú)法使用。因此,Q-learning算法只能用于解決離散低維狀態(tài)空間和動(dòng)作空間類(lèi)問(wèn)題。DQN算法的核心就是用一個(gè)人工神經(jīng)網(wǎng)絡(luò)來(lái)代替Q-tabel,即動(dòng)作價(jià)值函數(shù)。網(wǎng)絡(luò)的輸入為狀態(tài)信息,輸出為每個(gè)動(dòng)作的價(jià)值,因此DQN算法可以用來(lái)解決連續(xù)狀態(tài)空間和離散動(dòng)作空間問(wèn)題,無(wú)法解決連續(xù)動(dòng)作空間類(lèi)問(wèn)題。針對(duì)連續(xù)動(dòng)作空間類(lèi)問(wèn)題,后面blog會(huì)慢慢介紹。
2 DQN算法原理
DQN算法是一種off-policy算法,當(dāng)同時(shí)出現(xiàn)異策、自益和函數(shù)近似時(shí),無(wú)法保證收斂性,容易出現(xiàn)訓(xùn)練不穩(wěn)定或訓(xùn)練困難等問(wèn)題。針對(duì)這些問(wèn)題,研究人員主要從以下兩個(gè)方面進(jìn)行了改進(jìn)。
(1)經(jīng)驗(yàn)回放:將經(jīng)驗(yàn)(當(dāng)前狀態(tài)st、動(dòng)作at、即時(shí)獎(jiǎng)勵(lì)rt+1、下個(gè)狀態(tài)st+1、回合狀態(tài)done)存放在經(jīng)驗(yàn)池中,并按照一定的規(guī)則采樣。
(2)目標(biāo)網(wǎng)絡(luò):修改網(wǎng)絡(luò)的更新方式,例如不把剛學(xué)習(xí)到的網(wǎng)絡(luò)權(quán)重馬上用于后續(xù)的自益過(guò)程。
2.1 經(jīng)驗(yàn)回放
經(jīng)驗(yàn)回放就是一種讓經(jīng)驗(yàn)概率分布變得穩(wěn)定的技術(shù),可以提高訓(xùn)練的穩(wěn)定性。經(jīng)驗(yàn)回放主要有“存儲(chǔ)”和“回放”兩大關(guān)鍵步驟:
存儲(chǔ):將經(jīng)驗(yàn)以(st,at,rt+1,st+1,done)形式存儲(chǔ)在經(jīng)驗(yàn)池中。
回放:按照某種規(guī)則從經(jīng)驗(yàn)池中采樣一條或多條經(jīng)驗(yàn)數(shù)據(jù)。
從存儲(chǔ)的角度來(lái)看,經(jīng)驗(yàn)回放可以分為集中式回放和分布式回放:
- 集中式回放:智能體在一個(gè)環(huán)境中運(yùn)行,把經(jīng)驗(yàn)統(tǒng)一存儲(chǔ)在經(jīng)驗(yàn)池中。
- 分布式回放:多個(gè)智能體同時(shí)在多個(gè)環(huán)境中運(yùn)行,并將經(jīng)驗(yàn)統(tǒng)一存儲(chǔ)在經(jīng)驗(yàn)池中。由于多個(gè)智能體同時(shí)生成經(jīng)驗(yàn),所以能夠使用更多資源的同時(shí)更快地收集經(jīng)驗(yàn)。
從采樣的角度來(lái)看,經(jīng)驗(yàn)回放可以分為均勻回放和優(yōu)先回放:
- 均勻回放:等概率從經(jīng)驗(yàn)池中采樣經(jīng)驗(yàn)。
- 優(yōu)先回放:為經(jīng)驗(yàn)池中每條經(jīng)驗(yàn)指定一個(gè)優(yōu)先級(jí),在采樣經(jīng)驗(yàn)時(shí)更傾向于選擇優(yōu)先級(jí)更高的經(jīng)驗(yàn)。一般的做法是,如果某條經(jīng)驗(yàn)(例如經(jīng)驗(yàn))的優(yōu)先級(jí)為,那么選取該經(jīng)驗(yàn)的概率為:
優(yōu)先回放可以具體參照這篇論文:優(yōu)先經(jīng)驗(yàn)回放
經(jīng)驗(yàn)回放的優(yōu)點(diǎn):
1.在訓(xùn)練Q網(wǎng)絡(luò)時(shí),可以打破數(shù)據(jù)之間的相關(guān)性,使得數(shù)據(jù)滿(mǎn)足獨(dú)立同分布,從而減小參數(shù)更新的方差,提高收斂速度。
2.能夠重復(fù)使用經(jīng)驗(yàn),數(shù)據(jù)利用率高,對(duì)于數(shù)據(jù)獲取困難的情況尤其有用。
經(jīng)驗(yàn)回放的缺點(diǎn):
無(wú)法應(yīng)用于回合更新和多步學(xué)習(xí)算法。但是將經(jīng)驗(yàn)回放應(yīng)用于Q學(xué)習(xí),就規(guī)避了這個(gè)缺點(diǎn)。
代碼中采用集中式均勻回放,具體如下:
import numpy as np class ReplayBuffer: def __init__(self, state_dim, action_dim, max_size, batch_size): self.mem_size = max_size self.batch_size = batch_size self.mem_cnt = 0 self.state_memory = np.zeros((self.mem_size, state_dim)) self.action_memory = np.zeros((self.mem_size, )) self.reward_memory = np.zeros((self.mem_size, )) self.next_state_memory = np.zeros((self.mem_size, state_dim)) self.terminal_memory = np.zeros((self.mem_size, ), dtype=np.bool) def store_transition(self, state, action, reward, state_, done): mem_idx = self.mem_cnt % self.mem_size self.state_memory[mem_idx] = state self.action_memory[mem_idx] = action self.reward_memory[mem_idx] = reward self.next_state_memory[mem_idx] = state_ self.terminal_memory[mem_idx] = done self.mem_cnt += 1 def sample_buffer(self): mem_len = min(self.mem_size, self.mem_cnt) batch = np.random.choice(mem_len, self.batch_size, replace=True) states = self.state_memory[batch] actions = self.action_memory[batch] rewards = self.reward_memory[batch] states_ = self.next_state_memory[batch] terminals = self.terminal_memory[batch] return states, actions, rewards, states_, terminals def ready(self): return self.mem_cnt > self.batch_size
2.2 目標(biāo)網(wǎng)絡(luò)
對(duì)于基于自益的Q學(xué)習(xí),動(dòng)作價(jià)值估計(jì)和權(quán)重有關(guān)。當(dāng)權(quán)重變化時(shí),動(dòng)作價(jià)值的估計(jì)也會(huì)發(fā)生變化。在學(xué)習(xí)的過(guò)程中,動(dòng)作價(jià)值試圖追逐一個(gè)變化的回報(bào),容易出現(xiàn)不穩(wěn)定的情況。
目標(biāo)網(wǎng)絡(luò)是在原有的神經(jīng)網(wǎng)絡(luò)之外重新搭建一個(gè)結(jié)構(gòu)完全相同的網(wǎng)絡(luò)。原先的網(wǎng)絡(luò)稱(chēng)為評(píng)估網(wǎng)絡(luò),新構(gòu)建的網(wǎng)絡(luò)稱(chēng)為目標(biāo)網(wǎng)絡(luò)。在學(xué)習(xí)過(guò)程中,使用目標(biāo)網(wǎng)絡(luò)進(jìn)行自益得到回報(bào)的評(píng)估值,作為學(xué)習(xí)目標(biāo)。在更新過(guò)程中,只更新評(píng)估網(wǎng)絡(luò)的權(quán)重,而不更新目標(biāo)網(wǎng)絡(luò)的權(quán)重。這樣,更新權(quán)重時(shí)針對(duì)的目標(biāo)不會(huì)在每次迭代都發(fā)生變化,是一個(gè)固定的目標(biāo)。在更新一定次數(shù)后,再將評(píng)估網(wǎng)絡(luò)的權(quán)重復(fù)制給目標(biāo)網(wǎng)絡(luò),進(jìn)而進(jìn)行下一批更新,這樣目標(biāo)網(wǎng)絡(luò)也能得到更新。由于在目標(biāo)網(wǎng)絡(luò)沒(méi)有變化的一段時(shí)間內(nèi)回報(bào)的估計(jì)是相對(duì)固定的,因此目標(biāo)網(wǎng)絡(luò)的引入增加了學(xué)習(xí)的穩(wěn)定性。
目標(biāo)網(wǎng)絡(luò)的更新方式:
上述在一段時(shí)間內(nèi)固定目標(biāo)網(wǎng)絡(luò),一定次數(shù)后將評(píng)估網(wǎng)絡(luò)權(quán)重復(fù)制給目標(biāo)網(wǎng)絡(luò)的更新方式為硬更新(hard update),即
其中表示目標(biāo)網(wǎng)絡(luò)權(quán)重,表示評(píng)估網(wǎng)絡(luò)權(quán)重。
另外一種常用的更新方式為軟更新(soft update),即引入一個(gè)學(xué)習(xí)率,將舊的目標(biāo)網(wǎng)絡(luò)參數(shù)和新的評(píng)估網(wǎng)絡(luò)參數(shù)直接做加權(quán)平均后的值賦值給目標(biāo)網(wǎng)絡(luò)
學(xué)習(xí)率
3 DQN算法偽代碼
DQN算法的實(shí)現(xiàn)代碼為:
import torch as T import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import numpy as np from buffer import ReplayBuffer device = T.device("cuda:0" if T.cuda.is_available() else "cpu") class DeepQNetwork(nn.Module): def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim): super(DeepQNetwork, self).__init__() self.fc1 = nn.Linear(state_dim, fc1_dim) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.q = nn.Linear(fc2_dim, action_dim) self.optimizer = optim.Adam(self.parameters(), lr=alpha) self.to(device) def forward(self, state): x = T.relu(self.fc1(state)) x = T.relu(self.fc2(x)) q = self.q(x) return q def save_checkpoint(self, checkpoint_file): T.save(self.state_dict(), checkpoint_file, _use_new_zipfile_serialization=False) def load_checkpoint(self, checkpoint_file): self.load_state_dict(T.load(checkpoint_file)) class DQN: def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim, ckpt_dir, gamma=0.99, tau=0.005, epsilon=1.0, eps_end=0.01, eps_dec=5e-4, max_size=1000000, batch_size=256): self.tau = tau self.gamma = gamma self.epsilon = epsilon self.eps_min = eps_end self.eps_dec = eps_dec self.batch_size = batch_size self.action_space = [i for i in range(action_dim)] self.checkpoint_dir = ckpt_dir self.q_eval = DeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim, fc1_dim=fc1_dim, fc2_dim=fc2_dim) self.q_target = DeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim, fc1_dim=fc1_dim, fc2_dim=fc2_dim) self.memory = ReplayBuffer(state_dim=state_dim, action_dim=action_dim, max_size=max_size, batch_size=batch_size) self.update_network_parameters(tau=1.0) def update_network_parameters(self, tau=None): if tau is None: tau = self.tau for q_target_params, q_eval_params in zip(self.q_target.parameters(), self.q_eval.parameters()): q_target_params.data.copy_(tau * q_eval_params + (1 - tau) * q_target_params) def remember(self, state, action, reward, state_, done): self.memory.store_transition(state, action, reward, state_, done) def choose_action(self, observation, isTrain=True): state = T.tensor([observation], dtype=T.float).to(device) actions = self.q_eval.forward(state) action = T.argmax(actions).item() if (np.random.random() < self.epsilon) and isTrain: action = np.random.choice(self.action_space) return action def learn(self): if not self.memory.ready(): return states, actions, rewards, next_states, terminals = self.memory.sample_buffer() batch_idx = np.arange(self.batch_size) states_tensor = T.tensor(states, dtype=T.float).to(device) rewards_tensor = T.tensor(rewards, dtype=T.float).to(device) next_states_tensor = T.tensor(next_states, dtype=T.float).to(device) terminals_tensor = T.tensor(terminals).to(device) with T.no_grad(): q_ = self.q_target.forward(next_states_tensor) q_[terminals_tensor] = 0.0 target = rewards_tensor + self.gamma * T.max(q_, dim=-1)[0] q = self.q_eval.forward(states_tensor)[batch_idx, actions] loss = F.mse_loss(q, target.detach()) self.q_eval.optimizer.zero_grad() loss.backward() self.q_eval.optimizer.step() self.update_network_parameters() self.epsilon = self.epsilon - self.eps_dec if self.epsilon > self.eps_min else self.eps_min def save_models(self, episode): self.q_eval.save_checkpoint(self.checkpoint_dir + 'Q_eval/DQN_q_eval_{}.pth'.format(episode)) print('Saving Q_eval network successfully!') self.q_target.save_checkpoint(self.checkpoint_dir + 'Q_target/DQN_Q_target_{}.pth'.format(episode)) print('Saving Q_target network successfully!') def load_models(self, episode): self.q_eval.load_checkpoint(self.checkpoint_dir + 'Q_eval/DQN_q_eval_{}.pth'.format(episode)) print('Loading Q_eval network successfully!') self.q_target.load_checkpoint(self.checkpoint_dir + 'Q_target/DQN_Q_target_{}.pth'.format(episode)) print('Loading Q_target network successfully!')
算法仿真環(huán)境是在gym庫(kù)中的LunarLander-v2環(huán)境,因此需要先配置好gym庫(kù)。進(jìn)入Aanconda中對(duì)應(yīng)的Python環(huán)境中,執(zhí)行下面的指令
pip install gym
但是,這樣安裝的gym庫(kù)只包括少量的內(nèi)置環(huán)境,如算法環(huán)境、簡(jiǎn)單文字游戲環(huán)境和經(jīng)典控制環(huán)境,無(wú)法使用LunarLander-v2。
訓(xùn)練腳本如下:
import gym import numpy as np import argparse from DQN import DQN from utils import plot_learning_curve, create_directory parser = argparse.ArgumentParser() parser.add_argument('--max_episodes', type=int, default=500) parser.add_argument('--ckpt_dir', type=str, default='./checkpoints/DQN/') parser.add_argument('--reward_path', type=str, default='./output_images/avg_reward.png') parser.add_argument('--epsilon_path', type=str, default='./output_images/epsilon.png') args = parser.parse_args() def main(): env = gym.make('LunarLander-v2') agent = DQN(alpha=0.0003, state_dim=env.observation_space.shape[0], action_dim=env.action_space.n, fc1_dim=256, fc2_dim=256, ckpt_dir=args.ckpt_dir, gamma=0.99, tau=0.005, epsilon=1.0, eps_end=0.05, eps_dec=5e-4, max_size=1000000, batch_size=256) create_directory(args.ckpt_dir, sub_dirs=['Q_eval', 'Q_target']) total_rewards, avg_rewards, eps_history = [], [], [] for episode in range(args.max_episodes): total_reward = 0 done = False observation = env.reset() while not done: action = agent.choose_action(observation, isTrain=True) observation_, reward, done, info = env.step(action) agent.remember(observation, action, reward, observation_, done) agent.learn() total_reward += reward observation = observation_ total_rewards.append(total_reward) avg_reward = np.mean(total_rewards[-100:]) avg_rewards.append(avg_reward) eps_history.append(agent.epsilon) print('EP:{} reward:{} avg_reward:{} epsilon:{}'. format(episode + 1, total_reward, avg_reward, agent.epsilon)) if (episode + 1) % 50 == 0: agent.save_models(episode + 1) episodes = [i for i in range(args.max_episodes)] plot_learning_curve(episodes, avg_rewards, 'Reward', 'reward', args.reward_path) plot_learning_curve(episodes, eps_history, 'Epsilon', 'epsilon', args.epsilon_path) if __name__ == '__main__': main()
訓(xùn)練時(shí)還會(huì)用到畫(huà)圖函數(shù)和創(chuàng)建文件夾函數(shù),我將他們另外放在一個(gè)utils.py腳本中,具體代碼如下:
import os import matplotlib.pyplot as plt def plot_learning_curve(episodes, records, title, ylabel, figure_file): plt.figure() plt.plot(episodes, records, linestyle='-', color='r') plt.title(title) plt.xlabel('episode') plt.ylabel(ylabel) plt.show() plt.savefig(figure_file) def create_directory(path: str, sub_dirs: list): for sub_dir in sub_dirs: if os.path.exists(path + sub_dir): print(path + sub_dir + ' is already exist!') else: os.makedirs(path + sub_dir, exist_ok=True) print(path + sub_dir + ' create successfully!')
仿真結(jié)果如下圖所示:
通過(guò)平均獎(jiǎng)勵(lì)曲線(xiàn)可以看出,大概迭代到400步左右時(shí)算法趨于收斂。?
到此這篇關(guān)于Python深度強(qiáng)化學(xué)習(xí)之DQN算法原理詳解的文章就介紹到這了,更多相關(guān)Python DQN算法內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Django高并發(fā)負(fù)載均衡實(shí)現(xiàn)原理詳解
這篇文章主要介紹了Django高并發(fā)負(fù)載均衡實(shí)現(xiàn)原理詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04關(guān)于matplotlib及相關(guān)cmap參數(shù)的取值方式
這篇文章主要介紹了關(guān)于matplotlib及相關(guān)cmap參數(shù)的取值方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-11-11python實(shí)現(xiàn)H2O中的隨機(jī)森林算法介紹及其項(xiàng)目實(shí)戰(zhàn)
這篇文章主要介紹了python實(shí)現(xiàn)H2O中的隨機(jī)森林算法介紹及其項(xiàng)目實(shí)戰(zhàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08python實(shí)現(xiàn)簡(jiǎn)易計(jì)算器功能
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)簡(jiǎn)易計(jì)算器功能,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-02-02python深度學(xué)習(xí)人工智能BackPropagation鏈?zhǔn)椒▌t
這篇文章主要為大家介紹了python深度學(xué)習(xí)人工智能BackPropagation鏈?zhǔn)椒▌t的示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助2021-11-11在python3.64中安裝pyinstaller庫(kù)的方法步驟
這篇文章主要介紹了在python3.64中安裝pyinstaller庫(kù)的方法步驟,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-06-06PyCharm設(shè)置每行最大長(zhǎng)度限制的方法
今天小編就為大家分享一篇PyCharm設(shè)置每行最大長(zhǎng)度限制的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01PyQt5實(shí)現(xiàn)QLineEdit正則表達(dá)式輸入驗(yàn)證器
這篇文章主要介紹了PyQt5實(shí)現(xiàn)QLineEdit正則表達(dá)式輸入驗(yàn)證器,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04