教你用pytorch訓(xùn)練五子棋ai示例代碼
有4個(gè)文件
game.py 五子棋游戲
mod.py 神經(jīng)網(wǎng)絡(luò)模型
xl.py 訓(xùn)練的代碼
aigame.py 玩家與對(duì)戰(zhàn)的五子棋
game.py
class Game:
def __init__(self, h, w):
# 行數(shù)
self.h = h
# 列數(shù)
self.w = w
# 棋盤
self.L = [['-' for _ in range(w)] for _ in range(h)]
# 當(dāng)前玩家 - 表示空 X先下 然后是O
self.cur = 'X'
# 游戲勝利者
self.win_user = None
# 檢查下完這步后有沒(méi)有贏 y是行 x是列 返回True表示贏
def check_win(self, y, x):
directions = [
# 水平、垂直、兩個(gè)對(duì)角線方向
(1, 0), (0, 1), (1, 1), (1, -1)
]
player = self.L[y][x]
for dy, dx in directions:
count = 0
# 檢查四個(gè)方向上的連續(xù)相同棋子
for i in range(-4, 5): # 檢查-4到4的范圍,因?yàn)槲遄舆B珠需要5個(gè)棋子
ny, nx = y + i * dy, x + i * dx
if 0 <= ny < self.h and 0 <= nx < self.w and self.L[ny][nx] == player:
count += 1
if count == 5:
return True
else:
count = 0
return False
# 檢查能不能下這里 y行 x列 返回True表示能下
def check(self, y, x):
return self.L[y][x] == '-' and self.win_user is None
# 打印棋盤 可視化用得到
def __str__(self):
# 確定行號(hào)和列號(hào)的寬度
row_width = len(str(self.h - 1))
col_width = len(str(self.w - 1))
# 生成帶有行號(hào)和列號(hào)的棋盤字符串表示
result = []
# 添加列號(hào)標(biāo)題
result.append(' ' * (row_width + 1) + ' '.join(f'{i:>{col_width}}' for i in range(self.w)))
# 添加分隔線(可選)
result.append(' ' * (row_width + 1) + '-' * (col_width * self.w))
# 添加棋盤行
for y, row in enumerate(self.L):
# 添加行號(hào)
result.append(f'{y:>{row_width}} ' + ' '.join(f'{cell:>{col_width}}' for cell in row))
return '\n'.join(result)
# 一步棋
def set(self, y, x):
if self.win_user or not self.check(y, x):
return False
self.L[y][x] = self.cur
if self.check_win(y, x):
self.win_user = self.cur
return True
self.cur = 'X' if self.cur == 'O' else 'O'
return True
#和棋
def heqi(self):
for y in range(self.h):
for x in range(self.w):
if self.L[y][x]=='-':
return False
return True
#玩家自己下
def run_game01():
g = Game(15, 15)
while not g.win_user:
# 打印當(dāng)前棋盤狀態(tài)
while 1:
print(g)
try:
y,x=input(g.cur+':').split(',')
x=int(x)
y=int(y)
if g.set(y,x):
break
except Exception as e:
print(e)
print(g)
print('勝利者',g.win_user)
mod.py
import torch
import torch.nn as nn
import torch.optim as optim
from game import Game
class MyMod(nn.Module):
def __init__(self, input_channels=1, output_size=15*15):
super(MyMod, self).__init__()
# 定義卷積層,用于提取特征
self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1) # 輸出 32 x 15 x 15
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 輸出 64 x 15 x 15
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # 輸出 128 x 15 x 15
# 定義全連接層,用于最后的得分預(yù)測(cè)
self.fc1 = nn.Linear(128 * 15 * 15, 1024) # 展平后傳入全連接層
self.fc2 = nn.Linear(1024, output_size) # 輸出 15*15 的得分預(yù)測(cè)
def forward(self, x):
# 卷積層 -> 激活函數(shù) -> 最大池化
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = torch.relu(self.conv3(x))
# 將卷積層輸出展平為一維
x = x.view(x.size(0), -1)
# 全連接層
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 保存模型權(quán)重
def save(self, path):
torch.save(self.state_dict(), path)
# 加載模型權(quán)重
def load(self, path):
self.load_state_dict(torch.load(path))
#改進(jìn)一下 output 把有棋子的地方的概率=0避免下這些地方
# 輸入Game對(duì)象和MyMod對(duì)象,用于得到概率最大的落棋點(diǎn) (行y, 列x)
def input_qi(g: Game, m: MyMod):
# 獲取當(dāng)前棋盤狀態(tài)
board_state = g.L # 使用 game.L 獲取當(dāng)前棋盤的狀態(tài) (15x15的二維列表)
# 將棋盤狀態(tài)轉(zhuǎn)換為PyTorch的Tensor并增加一個(gè)維度(batch_size = 1)
board_tensor = torch.tensor([[1 if cell == 'X' else -1 if cell == 'O' else 0 for cell in row] for row in board_state],
dtype=torch.float32).unsqueeze(0).unsqueeze(0) # 形狀變?yōu)?(1, 1, 15, 15)
# 傳入模型獲取每個(gè)位置的得分
output = m(board_tensor)
# 將輸出轉(zhuǎn)為概率值(可以使用softmax來(lái)歸一化)
probabilities = torch.softmax(output, dim=-1).view(g.h, g.w).detach().numpy() # 變?yōu)?(15, 15) 大小
# 將已有棋子的位置的概率設(shè)置為 -inf,避免選擇這些位置
for y in range(g.h):
for x in range(g.w):
if board_state[y][x] != '-':
probabilities[y, x] = -float('inf') # 設(shè)置已經(jīng)有棋子的地方的概率為 -inf
# 找到概率最大的落子點(diǎn)
max_prob_pos = divmod(probabilities.argmax(), g.w) # 得到最大概率的行列坐標(biāo)
# 確保返回的是合法的位置
y, x = max_prob_pos
return (y, x), output # 返回坐標(biāo)和模型輸出
xl.py
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from mod import MyMod, input_qi, Game
# 兩個(gè)權(quán)重文件,分別代表 X 棋和 O 棋
MX = 'MX'
MO = 'MO'
# 加載模型,若文件不存在則初始化
def load_model(model, path):
if os.path.exists(path):
model.load(path)
print(f"Loaded model from {path}")
else:
print(f"{path} not found, initializing new model.")
# 這里可以加一些初始化模型的代碼,例如:
# model.apply(init_weights) 如果需要初始化權(quán)重
# 初始化模型
modx = MyMod()
load_model(modx, MX)
modo = MyMod()
load_model(modo, MO)
# 定義優(yōu)化器
lr=0.001
optimizer_x = optim.Adam(modx.parameters(), lr=lr)
optimizer_o = optim.Adam(modo.parameters(), lr=lr)
# 損失函數(shù):根據(jù)游戲結(jié)果調(diào)整損失
def compute_loss(winner: int, player: str, model_output):
# 將目標(biāo)值轉(zhuǎn)換為相應(yīng)的張量
if player == "X":
if winner == 1: # X 勝
target = torch.tensor(1.0, dtype=torch.float32)
elif winner == 0: # 平局
target = torch.tensor(0.5, dtype=torch.float32)
else: # X 輸
target = torch.tensor(0.0, dtype=torch.float32)
else:
if winner == -1: # O 勝
target = torch.tensor(1.0, dtype=torch.float32)
elif winner == 0: # 平局
target = torch.tensor(0.5, dtype=torch.float32)
else: # O 輸
target = torch.tensor(0.0, dtype=torch.float32)
# 確保目標(biāo)值的形狀和 model_output 一致,假設(shè) model_output 是單一的值
target = target.unsqueeze(0).unsqueeze(0) # 形狀變?yōu)?(1, 1)
# 使用均方誤差損失計(jì)算
return F.mse_loss(model_output, target)
# 訓(xùn)練模型的過(guò)程
def train_game():
modx.train()
modo.train()
# 創(chuàng)建新的游戲?qū)嵗?
game = Game(15, 15) # 默認(rèn)是 15x15 棋盤
# 反向傳播和優(yōu)化
optimizer_x.zero_grad()
optimizer_o.zero_grad()
while not game.win_user: # 游戲未結(jié)束
# X 方落子
x_move, x_output = input_qi(game, modx) # 獲取落子位置和模型輸出(x_output 是模型的輸出)
game.set(x_move[0], x_move[1]) # X 下棋
if game.win_user:
break
# O 方落子
o_move, o_output = input_qi(game, modo) # 獲取落子位置和模型輸出(o_output 是模型的輸出)
#print(o_move,game)
game.set(o_move[0], o_move[1]) # O 下棋
# 獲取比賽結(jié)果
winner = 0 if game.heqi() else (1 if game.win_user == 'X' else -1) # 1為X勝,-1為O勝,0為平局
# 計(jì)算損失
loss_x = compute_loss(winner, "X", x_output) # 傳遞模型輸出給計(jì)算損失函數(shù)
loss_o = compute_loss(winner, "O", o_output) # 傳遞模型輸出給計(jì)算損失函數(shù)
# 計(jì)算損失并進(jìn)行反向傳播
loss_x.backward()
loss_o.backward()
# 更新權(quán)重
optimizer_x.step()
optimizer_o.step()
print(game)
return loss_x.item(), loss_o.item()
# 訓(xùn)練多個(gè)回合
def train(num_epochs,n):
k=0
for epoch in range(num_epochs):
loss_x, loss_o = train_game()
print(f"Epoch [{epoch+1}/{num_epochs}], Loss X: {loss_x}, Loss O: {loss_o}")
k+=1
if k==n:
modo.save('MO')
modx.save('MX')
print('saved')
k=0
# 開(kāi)始訓(xùn)練
train(50000,1000)
aigame.py
from game import Game
from mod import MyMod,input_qi
#玩家下X ai下O
def playX():
m=MyMod()
m.load('MO')
g=Game(15,15)
while 1:
print(g)
if g.heqi() or g.win_user:
break
while 1:
try:
r=input('X:')
y,x=r.split(',')
y=int(y)
x=int(x)
if g.set(y,x):
break
except Exception as e:
print(e)
if g.heqi() or g.win_user:
break
while 1:
(y,x),_=input_qi(g,m)
if g.set(y,x):
break
print(g)
print('winner',g.win_user)
#玩家下O ai下X
def playO():
m=MyMod()
m.load('MX')
g=Game(15,15)
while 1:
if g.heqi() or g.win_user:
break
while 1:
(y,x),_=input_qi(g,m)
if g.set(y,x):
break
if g.heqi() or g.win_user:
break
print(g)
while 1:
try:
r=input('O:')
y,x=r.split(',')
y=int(y)
x=int(x)
if g.set(y,x):
break
except Exception as e:
print(e)
print(g)
print('winner',g.win_user)
playX()總結(jié)
到此這篇關(guān)于用pytorch訓(xùn)練五子棋ai的文章就介紹到這了,更多相關(guān)pytorch訓(xùn)練五子棋ai內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
詳解Python3操作Mongodb簡(jiǎn)明易懂教程
本篇文章主要介紹了詳解Python3操作Mongodb簡(jiǎn)明易懂教程,詳細(xì)的介紹了如何連接數(shù)據(jù)庫(kù)和對(duì)數(shù)據(jù)庫(kù)的操作,有需要的可以了解一下。2017-05-05
Python中聲明只包含一個(gè)元素的元組數(shù)據(jù)方法
這篇文章主要介紹了Python中聲明只包含一個(gè)元素的元組數(shù)據(jù)方法,本文是實(shí)際經(jīng)驗(yàn)總結(jié)而來(lái),沒(méi)有碰到這個(gè)需要可能不會(huì)注意到這個(gè)問(wèn)題,需要的朋友可以參考下2014-08-08
matplotlib 對(duì)坐標(biāo)的控制,加圖例注釋的操作
這篇文章主要介紹了matplotlib 對(duì)坐標(biāo)的控制,加圖例注釋的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-04-04
PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵?fù)p失
這篇文章主要介紹了PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵?fù)p失講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-02-02
Python?torch.onnx.export用法詳細(xì)介紹
這篇文章主要給大家介紹了關(guān)于Python?torch.onnx.export用法詳細(xì)介紹的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2022-07-07
Python實(shí)現(xiàn)批量讀取word中表格信息的方法
這篇文章主要介紹了Python實(shí)現(xiàn)批量讀取word中表格信息的方法,可實(shí)現(xiàn)針對(duì)word文檔的讀取功能,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-07-07
Python中subprocess介紹及如何使用詳細(xì)講解
在實(shí)際開(kāi)發(fā)過(guò)程中,我們經(jīng)常會(huì)遇到需要從Python腳本中調(diào)用外部程序或腳本的場(chǎng)景,下面這篇文章主要給大家介紹了關(guān)于Python中subprocess介紹及如何使用詳細(xì)講解的相關(guān)資料,需要的朋友可以參考下2024-09-09

