PyTorch如何使用ImageNet預訓練模型訓練
PyTorch使用ImageNet預訓練模型訓練
1、加載模型（loading models）
# Load ResNet-50 (used as the running example) with ImageNet-pretrained weights.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favor of
# `weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1`; the old flag is
# kept here because the rest of this tutorial targets older releases.
import torchvision

model = torchvision.models.resnet50(pretrained=True)
加載預訓練模型以後，我們要思考如何利用它；但在此之前，你必須先了解所加載模型的結構，例如用 print(model) 打印出各層（ResNet-50 最後的全連接層是 model.fc，輸入維度為 2048）。
2、處理分類數據
如果是用來處理分類數據：
你只需要替換最後一個全連接分類層，讓它輸出你自己的類別數（num_classes）。
# Replace the final fully connected layer with a new one that outputs
# `num_classes` scores. The input width is read from the existing layer instead
# of being hard-coded: 2048 is only right for ResNet-50/101/152, while
# ResNet-18/34 use 512. The nn.Sequential wrapper is kept so the parameter
# names (fc.0.weight / fc.0.bias) match checkpoints saved with the old line.
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, num_classes))
3、作為模型的backbone
如果你需要把它作為模型的 backbone，比如 R-CNN、語義分割（Semantic Segmentation）等，此時需要先把這些預訓練模型加載進來。下面以一個 FCN-8s 語義分割網絡為例：
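這裡的 model 就是前面加載的、在 ImageNet 數據集上預訓練過的 ResNet-50 模型。注意：下面代碼中 1×1 卷積的輸入通道數 256/128 對應的是 ResNet-18/34 的 layer3/layer2 輸出；若使用 ResNet-50，這兩層的輸出分別為 1024/512 通道，需要相應修改。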
class FCN(nn.Module):
    """FCN-8s style segmentation head on top of a torchvision-style ResNet.

    The backbone is not stored on this module; it is passed to ``forward`` on
    every call, so its parameters are not registered here (add them to the
    optimizer separately if they should be fine-tuned).

    Args:
        num_classes: number of output channels (segmentation classes). If
            omitted, the module-level global ``nClasses`` is used, which is
            what the original code always did.
        backbone_channels: output channels of the backbone's ``layer2``,
            ``layer3`` and ``layer4``. The default ``(128, 256, 512)`` matches
            ResNet-18/34; use ``(512, 1024, 2048)`` for ResNet-50/101/152.
    """

    def __init__(self, num_classes=None, backbone_channels=(128, 256, 512)):
        super(FCN, self).__init__()
        if num_classes is None:
            num_classes = nClasses  # legacy fallback: original read this global
        c2, c3, c4 = backbone_channels
        # Original submodule names are kept (and registered first) so existing
        # state_dict keys and their initialization order are unchanged.
        # 1x1 class-score conv on layer3 features (stride 16).
        self.layer1 = nn.Conv2d(c3, num_classes, 1, stride=1, padding=0, bias=True)
        # x2 upsampling: stride 16 -> stride 8.
        self.trans = nn.ConvTranspose2d(num_classes, num_classes, 2, stride=2, padding=0, bias=True)
        # 1x1 class-score conv on layer2 features (stride 8).
        self.layer2 = nn.Conv2d(c2, num_classes, 1, stride=1, padding=0, bias=True)
        # x8 upsampling: stride 8 -> input resolution.
        self.up = nn.ConvTranspose2d(num_classes, num_classes, 8, stride=8, padding=0, bias=True)
        # 1x1 class-score conv on layer4 features (stride 32) and its x2
        # upsampling to stride 16; missing in the original, whose forward added
        # raw layer4 features to class scores.
        self.score4 = nn.Conv2d(c4, num_classes, 1, stride=1, padding=0, bias=True)
        self.trans4 = nn.ConvTranspose2d(num_classes, num_classes, 2, stride=2, padding=0, bias=True)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, model):
        """Return per-pixel class scores for ``x``.

        Args:
            x: input batch, presumably (N, 3, H, W); for a standard ResNet, H
                and W should be multiples of 32 so the upsampled maps line up
                (TODO confirm for your input sizes).
            model: ResNet-like backbone exposing ``conv1``, ``bn1``, ``relu``,
                ``maxpool`` and ``layer1`` .. ``layer4``.

        Returns:
            Class-score map; (N, num_classes, H, W) under the assumptions above.
        """
        x = model.conv1(x)
        x = model.bn1(x)
        x = model.relu(x)
        x = model.maxpool(x)
        x = model.layer1(x)
        x1 = model.layer2(x)   # stride 8 for a standard ResNet
        x2 = model.layer3(x1)  # stride 16
        x3 = model.layer4(x2)  # stride 32
        # The original added the avg-pooled, unscored layer4 output to the
        # layer3 scores; its channel count is the backbone width rather than
        # num_classes, so the sum failed. Score and upsample layer4 instead;
        # model.avgpool is skipped because it discards spatial layout.
        y = self.trans4(self.score4(x3)) + self.layer1(x2)  # stride 16
        y = self.trans(y) + self.layer2(x1)                 # stride 8
        return self.up(y)                                   # input resolution
當然還有其他寫法，比如直接在類的構造函數裡先把這些子模塊取出來，後面用起來也非常簡單：
# Collect the backbone's submodules in traversal order, e.g. to rebuild part of
# it with nn.Sequential(...). Note that `model.modules()` yields `model` itself
# first and then recurses into every nested container; if only the top-level
# stages are wanted (for a torchvision ResNet: conv1, bn1, relu, maxpool,
# layer1..layer4, avgpool, fc), use `list(model.children())` instead, e.g.
# `nn.Sequential(*list(model.children())[:-2])` to drop avgpool and fc.
values = list(model.modules())
PyTorch ImageNet示例
"""PyTorch ImageNet training example.

Trains or evaluates a torchvision classification model on a dataset stored in
torchvision ``ImageFolder`` layout under ``DIR/train`` and ``DIR/val``.

Updated for current Python/PyTorch/torchvision:
- ``non_blocking=True`` instead of ``async=True`` (a SyntaxError since Python 3.7);
- ``Tensor.item()`` instead of indexing 0-dim tensors (``loss.data[0]``);
- ``torch.no_grad()`` instead of the no-op ``Variable(..., volatile=True)``;
- ``RandomResizedCrop`` / ``Resize`` instead of the removed
  ``RandomSizedCrop`` / ``Scale``.
"""
import argparse
import os
import shutil
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

# Every lower-case callable in torchvision.models whose name does not start
# with "__" (e.g. "resnet18") is accepted as an --arch value.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

# Best top-1 validation precision (percent) seen so far; updated by main().
best_prec1 = 0


def main():
    """Parse arguments, build model and data loaders, then train or evaluate."""
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        # Only the convolutional `features` part is wrapped in DataParallel;
        # the rest of the model runs on the default CUDA device.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            # RandomResizedCrop is the current name of the removed RandomSizedCrop.
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            # Resize is the current name of the removed Scale.
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch, printing running statistics every ``args.print_freq`` batches."""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # `async` is a reserved word since Python 3.7; `non_blocking` is the
        # current name of the same asynchronous host-to-device copy flag.
        target = target.cuda(non_blocking=True)

        # compute output (Variable wrappers are unnecessary since PyTorch 0.4)
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss (.item() converts to Python floats)
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))
        top5.update(prec5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision (percent)."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # torch.no_grad() replaces the removed `volatile=True` flag so that no
    # autograd graph is built during evaluation.
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` to ``filename``; also copy it to 'model_best.pth.tar' when ``is_best``."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Returns one 1-element tensor per k, holding the percentage of samples whose
    target is among the top-k predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` is a transposed (non-contiguous) tensor, so .view(-1)
            # fails on it; reshape copies when needed. keepdim keeps a 1-element
            # tensor rather than a 0-dim one.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()
總結
以上為個人經驗，希望能給大家一個參考，也希望大家多多支持腳本之家。
相關文章
python實現動態GIF英數驗證碼識別示例
這篇文章主要為大家介紹了python實現動態GIF英數驗證碼識別示例，有需要的朋友可以借鑒參考下，希望能夠有所幫助，祝大家多多進步，早日升職加薪。2024-01-01
Python向Excel中插入圖片的簡單實現方法
這篇文章主要介紹了Python向Excel中插入圖片的簡單實現方法，結合實例形式分析了Python使用XlsxWriter模塊操作Excel單元格插入jpg格式圖片的相關操作技巧，非常簡單實用，需要的朋友可以參考下。2018-04-04
Python使用scrapy采集數據過程中放回下載過大頁面的方法
這篇文章主要介紹了Python使用scrapy采集數據過程中放回下載過大頁面的方法，可實現限制下載過大頁面的功能，非常具有實用價值，需要的朋友可以參考下。2015-04-04
解決Django Static內容不能加載顯示的問題
今天小編就為大家分享一篇解決Django Static內容不能加載顯示的問題，具有很好的參考價值，希望對大家有所幫助。一起跟隨小編過來看看吧。2019-07-07
Python實現將一段話txt生成字幕srt文件
這篇文章主要為大家詳細介紹了如何利用Python實現將一段話txt生成字幕srt文件，文中的示例代碼講解詳細，感興趣的小夥伴可以了解一下。2023-02-02