PyTorch如何使用ImageNet預訓練模型進行訓練
PyTorch使用ImageNet預訓練模型進行訓練
1、加載模型（Loading models）
# Load a ResNet-50 with ImageNet-pretrained weights (the example model used below).
# Note: torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
# ``weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1``.
import torchvision

model = torchvision.models.resnet50(pretrained=True)
加載模型以後，我們要思考如何利用它；但在此之前，你必須先了解所加載模型的結構（例如用 print(model) 打印出各層的名稱）。
2、處理分類數據
如果是用來處理分類任務：
你只需要替換最後一個全連接分類層（model.fc），讓它輸出你自己的類別數：
# Replace the ImageNet classifier with one sized for your own dataset.
# Reading in_features from the existing layer (2048 for ResNet-50, 512 for
# ResNet-18/34) avoids hard-coding the backbone width. The nn.Sequential
# wrapper is kept so the state_dict keys stay "fc.0.weight"/"fc.0.bias".
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, num_classes))
3、作為模型的backbone
如果你需要把它作為其他模型的backbone，比如R-CNN、語義分割（Semantic Segmentation）等，此時要先把這些預訓練模型加載進來。下面以一個FCN-8s語義分割模型為例：
這裡的model就是前面加載的、在ImageNet數據集上預訓練好的ResNet50模型。
class FCN(nn.Module):
    """FCN-8s style segmentation head for a torchvision ResNet backbone.

    Class scores are predicted from the backbone's ``layer4`` output. They are
    fused with scores from ``layer3`` and then ``layer2``, using learned 2x
    upsampling between stages. A final 8x upsampling restores the input
    resolution.

    The backbone is NOT registered on this module: it is passed to
    ``forward`` on every call. As a result, its parameters are not in
    ``FCN.parameters()`` and are not moved by ``FCN.to()``/``.cuda()``. Move
    it yourself, and give its parameters to the optimizer if it should be
    fine-tuned.

    Args:
        n_classes: number of output classes. If omitted, the module-level
            global ``nClasses`` is used, as the original snippet did.
        in_channels: channel counts of the backbone's ``layer2``, ``layer3``
            and ``layer4`` outputs. The default matches ResNet-50; use
            ``(128, 256, 512)`` for ResNet-18/34.
    """

    def __init__(self, n_classes=None, in_channels=(512, 1024, 2048)):
        super(FCN, self).__init__()
        if n_classes is None:
            n_classes = nClasses  # backward compatibility with plain FCN()
        c2, c3, c4 = in_channels
        # 1x1 "score" convolutions project each feature map to class logits.
        self.score_layer4 = nn.Conv2d(c4, n_classes, kernel_size=1)
        self.layer1 = nn.Conv2d(c3, n_classes, kernel_size=1)
        self.layer2 = nn.Conv2d(c2, n_classes, kernel_size=1)
        # Learned 2x upsamplings: stride 32 -> 16, then stride 16 -> 8.
        self.up_layer4 = nn.ConvTranspose2d(n_classes, n_classes, 2, stride=2)
        self.trans = nn.ConvTranspose2d(n_classes, n_classes, 2, stride=2)
        # Final 8x upsampling back to the input resolution.
        self.up = nn.ConvTranspose2d(n_classes, n_classes, 8, stride=8)
        # Only the head layers registered above are (re)initialised; the
        # pretrained backbone is not a submodule, so its weights are untouched.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, model):
        """Return per-pixel class logits of shape (N, n_classes, H, W).

        ``model`` must expose ResNet-style attributes ``conv1``, ``bn1``,
        ``relu``, ``maxpool`` and ``layer1``-``layer4``. H and W are assumed
        to be multiples of 32 so the skip connections line up; TODO confirm
        for non-ResNet backbones.
        """
        x = model.conv1(x)
        x = model.bn1(x)
        x = model.relu(x)
        x = model.maxpool(x)
        x = model.layer1(x)
        c2 = model.layer2(x)   # stride 8
        c3 = model.layer3(c2)  # stride 16
        c4 = model.layer4(c3)  # stride 32
        # No global pooling here: segmentation needs the spatial map.
        y = self.up_layer4(self.score_layer4(c4)) + self.layer1(c3)
        y = self.trans(y) + self.layer2(c2)
        return self.up(y)

# Alternatively, the backbone's layers can be pulled out once up front (for
# example in the class constructor); the rest is then straightforward:
# Collect the backbone's top-level layers (for a torchvision ResNet: conv1,
# bn1, relu, maxpool, layer1..layer4, avgpool, fc) so they can be sliced and
# reassembled, e.g. with nn.Sequential.
# model.children() is used rather than model.modules(). modules() also yields
# the model itself plus every nested sub-layer, so rebuilding a network from
# it would duplicate layers.
values = list(model.children())
# e.g. backbone = nn.Sequential(*list(model.children())[:-2])  # keep conv1..layer4, drop avgpool and fc
PyTorch ImageNet示例
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
# Sorted names of every lower-case, public, callable entry of
# torchvision.models (e.g. the 'resnet18' constructor); these become the
# allowed values of --arch.
model_names = sorted(
    name
    for name, entry in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(entry)
)
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Positional: dataset root; main() reads its 'train' and 'val' sub-folders.
parser.add_argument('data',
                    metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a',
                    metavar='ARCH',
                    default='resnet18',
                    choices=model_names,
                    help='model architecture: {} (default: resnet18)'.format(
                        ' | '.join(model_names)))
parser.add_argument('-j', '--workers',
                    type=int, default=4, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs',
                    type=int, default=90, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch',
                    type=int, default=0, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size',
                    type=int, default=256, metavar='N',
                    help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate',
                    type=float, default=0.1, metavar='LR',
                    help='initial learning rate')
parser.add_argument('--momentum',
                    type=float, default=0.9, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd',
                    type=float, default=1e-4, metavar='W',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p',
                    type=int, default=10, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--resume',
                    type=str, default='', metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate',
                    action='store_true', dest='evaluate',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained',
                    action='store_true', dest='pretrained',
                    help='use pre-trained model')

# Best validation top-1 precision seen so far; updated (and checkpointed) by main().
best_prec1 = 0
def main():
    """Parse the command line, build the model and run training or evaluation.

    Side effects:
    - Stores the parsed namespace in the global ``args``, which ``train``,
      ``validate`` and ``adjust_learning_rate`` read.
    - Tracks the best validation top-1 precision in the global ``best_prec1``.
    - Writes a checkpoint after every epoch via ``save_checkpoint``.

    ``args.data`` must contain ``train`` and ``val`` folders in
    ``ImageFolder`` layout. The model and loss are moved to CUDA
    unconditionally, so a GPU is required.
    """
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith(('alexnet', 'vgg')):
        # Only the convolutional feature extractor is wrapped in DataParallel;
        # the rest of the model runs on the default GPU.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN pick the fastest convolution algorithms; every batch has the
    # same 224x224 crop size.
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # Standard ImageNet per-channel mean/std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # RandomResizedCrop / Resize are the current names of the removed
    # RandomSizedCrop / Scale transforms.
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Takes one optimizer step per mini-batch, updating ``model`` in place.
    Every ``args.print_freq`` batches it prints running timing, loss and
    precision statistics. ``args`` is the global set by ``main``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # ``non_blocking`` replaces the old ``async`` keyword argument, which
        # is a SyntaxError on Python 3.7+. The images are left on the CPU;
        # presumably the model is wrapped in DataParallel (as main() does),
        # which moves them to the GPUs.
        target = target.cuda(non_blocking=True)

        # compute output (plain tensors; autograd.Variable is no longer needed)
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss. .item() yields Python floats and
        # works for both 0-dim and 1-element tensors, unlike indexing with [0].
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))
        top5.update(prec5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader``.

    Prints running statistics every ``args.print_freq`` batches (``args`` is
    the global set by ``main``), followed by a final summary line.

    Returns:
        The average top-1 precision over the whole loader, in percent.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # torch.no_grad() replaces ``volatile=True``, which modern PyTorch
    # ignores. No autograd graph is recorded, which saves memory.
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # ``non_blocking`` replaces ``async`` (a SyntaxError on 3.7+).
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist a training checkpoint.

    ``state`` is written to ``filename`` with ``torch.save``. When
    ``is_best`` is true, that file is additionally copied to
    'model_best.pth.tar' in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Track the latest value and a count-weighted running mean.

    Attributes:
        val: the most recently recorded value.
        sum: running total of ``val * n`` over all updates.
        count: total weight ``n`` recorded so far.
        avg: ``sum / count``; 0 until the first update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Apply a step-decay schedule to every parameter group of ``optimizer``.

    The rate is the initial LR (global ``args.lr``) multiplied by 0.1 once
    for every 30 completed epochs.
    """
    decay = 0.1 ** (epoch // 30)
    new_lr = args.lr * decay
    for group in optimizer.param_groups:
        group['lr'] = new_lr
def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each k in ``topk``.

    Args:
        output: score tensor indexed as (batch, classes). It is assumed to be
            2-D, since ``topk`` is taken along dim 1.
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k. Each holds the
        percentage of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk highest scores per row, transposed to
        # (maxk, batch) so row j holds every sample's j-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: ``correct`` is derived from the transposed
            # ``pred`` and may be non-contiguous, where view(-1) raises.
            # keepdim=True keeps a 1-element tensor so callers can use [0].
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
if __name__ == '__main__':
    # Script entry point: parse the command line and start training/evaluation.
    main()

總結
以上為個人經驗，希望能給大家一個參考，也希望大家多多支持腳本之家。
相關文章
python實現動態GIF英數驗證碼識別示例
這篇文章主要為大家介紹了python實現動態GIF英數驗證碼識別示例，有需要的朋友可以借鑒參考下，希望能夠有所幫助，祝大家多多進步，早日升職加薪。 2024-01-01
Python向Excel中插入圖片的簡單實現方法
這篇文章主要介紹了Python向Excel中插入圖片的簡單實現方法，結合實例形式分析了Python使用XlsxWriter模塊操作Excel單元格插入jpg格式圖片的相關操作技巧，非常簡單實用，需要的朋友可以參考下。 2018-04-04
Python使用scrapy采集數據過程中放回下載過大頁面的方法
這篇文章主要介紹了Python使用scrapy采集數據過程中放回下載過大頁面的方法，可實現限制下載過大頁面的功能，非常具有實用價值，需要的朋友可以參考下。 2015-04-04
python實現微信自動回復機器人功能
wxpy基於itchat，使用了Web微信的通訊協議，通過大量接口優化提升了模塊的易用性，並進行了豐富的功能擴展。這篇文章主要介紹了python實現微信自動回復機器人功能，需要的朋友可以參考下。 2019-07-07
解決Django Static內容不能加載顯示的問題
今天小編就為大家分享一篇解決Django Static內容不能加載顯示的問題，具有很好的參考價值，希望對大家有所幫助。一起跟隨小編過來看看吧。 2019-07-07
Python實現將一段話txt生成字幕srt文件
這篇文章主要為大家詳細介紹了如何利用Python實現將一段話txt生成字幕srt文件，文中的示例代碼講解詳細，感興趣的小伙伴可以了解一下。 2023-02-02

