從Pytorch模型pth文件中讀取參數(shù)成numpy矩陣的操作
目的:
把訓(xùn)練好的pth模型參數(shù)提取出來,然后用其他方式部署到邊緣設(shè)備。
Pytorch給了很方便的讀取參數(shù)接口:
nn.Module.parameters()
直接看demo:
from torchvision.models.alexnet import alexnet model = alexnet(pretrained=True).eval().cuda() parameters = model.parameters() for p in parameters: numpy_para = p.detach().cpu().numpy() print(type(numpy_para)) print(numpy_para.shape)
上面得到的numpy_para就是numpy參數(shù)了~
Note:
model.parameters()是以一個生成器的形式迭代返回每一層的參數(shù)。所以用for循環(huán)讀取到各層的參數(shù),循環(huán)次數(shù)就表示層數(shù)。
而每一層的參數(shù)都是torch.nn.parameter.Parameter類型,是Tensor的子類,所以直接用tensor轉(zhuǎn)numpy(即p.detach().cpu().numpy())的方法就可以直接轉(zhuǎn)成numpy矩陣。
方便又好用,爆贊~
補充:pytorch訓(xùn)練好的.pth模型轉(zhuǎn)換為.pt
將python訓(xùn)練好的.pth文件轉(zhuǎn)為.pt
import torch import torchvision from unet import UNet model = UNet(3, 2)#自己定義的網(wǎng)絡(luò)模型 model.load_state_dict(torch.load("best_weights.pth"))#保存的訓(xùn)練模型 model.eval()#切換到eval() example = torch.rand(1, 3, 320, 480)#生成一個隨機輸入維度的輸入 traced_script_module = torch.jit.trace(model, example) traced_script_module.save("model.pt")
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。
相關(guān)文章
淺談python中常用的8種經(jīng)典數(shù)據(jù)結(jié)構(gòu)
這篇文章主要介紹了python中常用的8種經(jīng)典數(shù)據(jù)結(jié)構(gòu),包括原生數(shù)據(jù)結(jié)構(gòu),NumPy包中的數(shù)據(jù)結(jié)構(gòu),以及Pandas包中的數(shù)據(jù)結(jié)構(gòu),需要的朋友可以參考下2023-03-03python 如何將office文件轉(zhuǎn)換為PDF
這篇文章主要介紹了python 如何將office文件轉(zhuǎn)換為PDF,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2020-09-09django 解決自定義序列化返回處理數(shù)據(jù)為null的問題
這篇文章主要介紹了django 解決自定義序列化返回處理數(shù)據(jù)為null的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05