教你用pytorch訓(xùn)練五子棋ai示例代碼
有4個文件
game.py 五子棋游戲
mod.py 神經(jīng)網(wǎng)絡(luò)模型
xl.py 訓(xùn)練的代碼
aigame.py 玩家與對戰(zhàn)的五子棋
game.py
class Game: def __init__(self, h, w): # 行數(shù) self.h = h # 列數(shù) self.w = w # 棋盤 self.L = [['-' for _ in range(w)] for _ in range(h)] # 當(dāng)前玩家 - 表示空 X先下 然后是O self.cur = 'X' # 游戲勝利者 self.win_user = None # 檢查下完這步后有沒有贏 y是行 x是列 返回True表示贏 def check_win(self, y, x): directions = [ # 水平、垂直、兩個對角線方向 (1, 0), (0, 1), (1, 1), (1, -1) ] player = self.L[y][x] for dy, dx in directions: count = 0 # 檢查四個方向上的連續(xù)相同棋子 for i in range(-4, 5): # 檢查-4到4的范圍,因為五子連珠需要5個棋子 ny, nx = y + i * dy, x + i * dx if 0 <= ny < self.h and 0 <= nx < self.w and self.L[ny][nx] == player: count += 1 if count == 5: return True else: count = 0 return False # 檢查能不能下這里 y行 x列 返回True表示能下 def check(self, y, x): return self.L[y][x] == '-' and self.win_user is None # 打印棋盤 可視化用得到 def __str__(self): # 確定行號和列號的寬度 row_width = len(str(self.h - 1)) col_width = len(str(self.w - 1)) # 生成帶有行號和列號的棋盤字符串表示 result = [] # 添加列號標(biāo)題 result.append(' ' * (row_width + 1) + ' '.join(f'{i:>{col_width}}' for i in range(self.w))) # 添加分隔線(可選) result.append(' ' * (row_width + 1) + '-' * (col_width * self.w)) # 添加棋盤行 for y, row in enumerate(self.L): # 添加行號 result.append(f'{y:>{row_width}} ' + ' '.join(f'{cell:>{col_width}}' for cell in row)) return '\n'.join(result) # 一步棋 def set(self, y, x): if self.win_user or not self.check(y, x): return False self.L[y][x] = self.cur if self.check_win(y, x): self.win_user = self.cur return True self.cur = 'X' if self.cur == 'O' else 'O' return True #和棋 def heqi(self): for y in range(self.h): for x in range(self.w): if self.L[y][x]=='-': return False return True #玩家自己下 def run_game01(): g = Game(15, 15) while not g.win_user: # 打印當(dāng)前棋盤狀態(tài) while 1: print(g) try: y,x=input(g.cur+':').split(',') x=int(x) y=int(y) if g.set(y,x): break except Exception as e: print(e) print(g) print('勝利者',g.win_user)
mod.py
import torch import torch.nn as nn import torch.optim as optim from game import Game class MyMod(nn.Module): def __init__(self, input_channels=1, output_size=15*15): super(MyMod, self).__init__() # 定義卷積層,用于提取特征 self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1) # 輸出 32 x 15 x 15 self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 輸出 64 x 15 x 15 self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # 輸出 128 x 15 x 15 # 定義全連接層,用于最后的得分預(yù)測 self.fc1 = nn.Linear(128 * 15 * 15, 1024) # 展平后傳入全連接層 self.fc2 = nn.Linear(1024, output_size) # 輸出 15*15 的得分預(yù)測 def forward(self, x): # 卷積層 -> 激活函數(shù) -> 最大池化 x = torch.relu(self.conv1(x)) x = torch.relu(self.conv2(x)) x = torch.relu(self.conv3(x)) # 將卷積層輸出展平為一維 x = x.view(x.size(0), -1) # 全連接層 x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # 保存模型權(quán)重 def save(self, path): torch.save(self.state_dict(), path) # 加載模型權(quán)重 def load(self, path): self.load_state_dict(torch.load(path)) #改進一下 output 把有棋子的地方的概率=0避免下這些地方 # 輸入Game對象和MyMod對象,用于得到概率最大的落棋點 (行y, 列x) def input_qi(g: Game, m: MyMod): # 獲取當(dāng)前棋盤狀態(tài) board_state = g.L # 使用 game.L 獲取當(dāng)前棋盤的狀態(tài) (15x15的二維列表) # 將棋盤狀態(tài)轉(zhuǎn)換為PyTorch的Tensor并增加一個維度(batch_size = 1) board_tensor = torch.tensor([[1 if cell == 'X' else -1 if cell == 'O' else 0 for cell in row] for row in board_state], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # 形狀變?yōu)?(1, 1, 15, 15) # 傳入模型獲取每個位置的得分 output = m(board_tensor) # 將輸出轉(zhuǎn)為概率值(可以使用softmax來歸一化) probabilities = torch.softmax(output, dim=-1).view(g.h, g.w).detach().numpy() # 變?yōu)?(15, 15) 大小 # 將已有棋子的位置的概率設(shè)置為 -inf,避免選擇這些位置 for y in range(g.h): for x in range(g.w): if board_state[y][x] != '-': probabilities[y, x] = -float('inf') # 設(shè)置已經(jīng)有棋子的地方的概率為 -inf # 找到概率最大的落子點 max_prob_pos = divmod(probabilities.argmax(), g.w) # 得到最大概率的行列坐標(biāo) # 確保返回的是合法的位置 y, x = max_prob_pos return (y, x), output # 返回坐標(biāo)和模型輸出
xl.py
import os import torch import torch.optim as optim import torch.nn.functional as F from mod import MyMod, input_qi, Game # 兩個權(quán)重文件,分別代表 X 棋和 O 棋 MX = 'MX' MO = 'MO' # 加載模型,若文件不存在則初始化 def load_model(model, path): if os.path.exists(path): model.load(path) print(f"Loaded model from {path}") else: print(f"{path} not found, initializing new model.") # 這里可以加一些初始化模型的代碼,例如: # model.apply(init_weights) 如果需要初始化權(quán)重 # 初始化模型 modx = MyMod() load_model(modx, MX) modo = MyMod() load_model(modo, MO) # 定義優(yōu)化器 lr=0.001 optimizer_x = optim.Adam(modx.parameters(), lr=lr) optimizer_o = optim.Adam(modo.parameters(), lr=lr) # 損失函數(shù):根據(jù)游戲結(jié)果調(diào)整損失 def compute_loss(winner: int, player: str, model_output): # 將目標(biāo)值轉(zhuǎn)換為相應(yīng)的張量 if player == "X": if winner == 1: # X 勝 target = torch.tensor(1.0, dtype=torch.float32) elif winner == 0: # 平局 target = torch.tensor(0.5, dtype=torch.float32) else: # X 輸 target = torch.tensor(0.0, dtype=torch.float32) else: if winner == -1: # O 勝 target = torch.tensor(1.0, dtype=torch.float32) elif winner == 0: # 平局 target = torch.tensor(0.5, dtype=torch.float32) else: # O 輸 target = torch.tensor(0.0, dtype=torch.float32) # 確保目標(biāo)值的形狀和 model_output 一致,假設(shè) model_output 是單一的值 target = target.unsqueeze(0).unsqueeze(0) # 形狀變?yōu)?(1, 1) # 使用均方誤差損失計算 return F.mse_loss(model_output, target) # 訓(xùn)練模型的過程 def train_game(): modx.train() modo.train() # 創(chuàng)建新的游戲?qū)嵗? game = Game(15, 15) # 默認是 15x15 棋盤 # 反向傳播和優(yōu)化 optimizer_x.zero_grad() optimizer_o.zero_grad() while not game.win_user: # 游戲未結(jié)束 # X 方落子 x_move, x_output = input_qi(game, modx) # 獲取落子位置和模型輸出(x_output 是模型的輸出) game.set(x_move[0], x_move[1]) # X 下棋 if game.win_user: break # O 方落子 o_move, o_output = input_qi(game, modo) # 獲取落子位置和模型輸出(o_output 是模型的輸出) #print(o_move,game) game.set(o_move[0], o_move[1]) # O 下棋 # 獲取比賽結(jié)果 winner = 0 if game.heqi() else (1 if game.win_user == 'X' else -1) # 1為X勝,-1為O勝,0為平局 # 計算損失 loss_x = compute_loss(winner, "X", x_output) # 傳遞模型輸出給計算損失函數(shù) loss_o = compute_loss(winner, "O", o_output) # 傳遞模型輸出給計算損失函數(shù) # 計算損失并進行反向傳播 loss_x.backward() loss_o.backward() # 更新權(quán)重 optimizer_x.step() optimizer_o.step() print(game) return loss_x.item(), loss_o.item() # 訓(xùn)練多個回合 def train(num_epochs,n): k=0 for epoch in range(num_epochs): loss_x, loss_o = train_game() print(f"Epoch [{epoch+1}/{num_epochs}], Loss X: {loss_x}, Loss O: {loss_o}") k+=1 if k==n: modo.save('MO') modx.save('MX') print('saved') k=0 # 開始訓(xùn)練 train(50000,1000)
aigame.py
from game import Game from mod import MyMod,input_qi #玩家下X ai下O def playX(): m=MyMod() m.load('MO') g=Game(15,15) while 1: print(g) if g.heqi() or g.win_user: break while 1: try: r=input('X:') y,x=r.split(',') y=int(y) x=int(x) if g.set(y,x): break except Exception as e: print(e) if g.heqi() or g.win_user: break while 1: (y,x),_=input_qi(g,m) if g.set(y,x): break print(g) print('winner',g.win_user) #玩家下O ai下X def playO(): m=MyMod() m.load('MX') g=Game(15,15) while 1: if g.heqi() or g.win_user: break while 1: (y,x),_=input_qi(g,m) if g.set(y,x): break if g.heqi() or g.win_user: break print(g) while 1: try: r=input('O:') y,x=r.split(',') y=int(y) x=int(x) if g.set(y,x): break except Exception as e: print(e) print(g) print('winner',g.win_user) playX()
總結(jié)
到此這篇關(guān)于用pytorch訓(xùn)練五子棋ai的文章就介紹到這了,更多相關(guān)pytorch訓(xùn)練五子棋ai內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中聲明只包含一個元素的元組數(shù)據(jù)方法
這篇文章主要介紹了Python中聲明只包含一個元素的元組數(shù)據(jù)方法,本文是實際經(jīng)驗總結(jié)而來,沒有碰到這個需要可能不會注意到這個問題,需要的朋友可以參考下2014-08-08matplotlib 對坐標(biāo)的控制,加圖例注釋的操作
這篇文章主要介紹了matplotlib 對坐標(biāo)的控制,加圖例注釋的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵損失
這篇文章主要介紹了PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵損失講解,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-02-02Python?torch.onnx.export用法詳細介紹
這篇文章主要給大家介紹了關(guān)于Python?torch.onnx.export用法詳細介紹的相關(guān)資料,文中通過實例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2022-07-07Python實現(xiàn)批量讀取word中表格信息的方法
這篇文章主要介紹了Python實現(xiàn)批量讀取word中表格信息的方法,可實現(xiàn)針對word文檔的讀取功能,具有一定參考借鑒價值,需要的朋友可以參考下2015-07-07