Pytorch的mean和std調查實例
更新時間:2020年01月02日 10:09:43 作者:機器學習的小學生
今天小編就為大家分享一篇Pytorch的mean和std調查實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
如下所示:
# coding: utf-8
from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms
import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio
# if model has LSTM
# torch.backends.cudnn.enabled = False
imgpath = 'D:/ck/files_detected_face224/'
imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=mean_file,
std =std_file,
#mean = mean_file,
#std = std_file,
)
])(raw_image).unsqueeze(0)
print(image.shape)
convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)
convert_image1 = convert_image1 * 255
diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()
結論:
input_image = (raw_image / 255 - mean) ./ std
下面調查均值文件和方差文件是如何生成的:
mean_file = [0.485, 0.456, 0.406] std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms
dataset_names = ('cifar10','cifar100','mnist')
parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')
args = parser.parse_args()
data_dir = os.path.join('.', args.dataset)
print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
#print(vars(train_set))
print(train_set.train_data.shape)
print(train_set.train_data.mean(axis=(0,1,2))/255)
print(train_set.train_data.std(axis=(0,1,2))/255)
# imshow image
train_data = train_set.train_data
ind = 100
img0 = train_data[ind,...]
## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
# error produce
#b,g,r=cv2.split(img0)
#img0=cv2.merge([r,g,b])
print(img0.shape)
print(type(img0))
plt.imshow(img0)
plt.show() # in ship in sea
#img0 = cv2.resize(img0,(224,224))
#cv2.imshow('img0',img0)
#cv2.waitKey()
elif args.dataset == "cifar100":
train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
#print(vars(train_set))
print(train_set.train_data.shape)
print(np.mean(train_set.train_data, axis=(0,1,2))/255)
print(np.std(train_set.train_data, axis=(0,1,2))/255)
elif args.dataset == "mnist":
train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
#print(vars(train_set))
print(list(train_set.train_data.size()))
print(train_set.train_data.float().mean()/255)
print(train_set.train_data.float().std()/255)
結果:
cifar10 Files already downloaded and verified (50000, 32, 32, 3) [ 0.49139968 0.48215841 0.44653091] [ 0.24703223 0.24348513 0.26158784] (32, 32, 3) <class 'numpy.ndarray'>
使用matlab檢測是如何計算mean_file和std_file的:
% load cifar10 dataset
data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));
temp = mean(train_data,1);
disp(size(temp));
train_data = double(train_data);
% compute mean_file
mean_val = mean(mean(mean(train_data,1),2),3)/255;
% compute std_file
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;
temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;
temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;
mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];
disp(mean_val);
disp(std_val);
% result: mean_val: [0.4914, 0.4822, 0.4465]
% std_val: [0.2470, 0.2435, 0.2616]
均值計算的過程也可以遵循標準差的計算過程。為 了簡單,例如對于一個矩陣,所有元素的均值,等于兩個方向上先后均值。所以會直接采用如下的形式:
mean_val = mean(mean(mean(train_data,1),2),3)/255;
標準差的計算是每一個通道的對所有樣本的求標準差。然后再除以255。
以上這篇Pytorch的mean和std調查實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
13行python代碼實現(xiàn)對微信進行推送消息的示例代碼
本文主要介紹了13行python代碼實現(xiàn)對微信進行推送消息的示例代碼,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2022-08-08
Python還能這么玩之只用30行代碼從excel提取個人值班表
公司實行項目值班制度,拿到值班表,看到全部的值班信息,要去查找自己的值班信息,是一件頭痛的事情.作為程序員,當然要簡化,將自己的信息提煉出來,需要的朋友可以參考下2021-06-06
Pandas實現(xiàn)groupby分組統(tǒng)計方法實例
在數(shù)據(jù)處理的過程,有可能需要對一堆數(shù)據(jù)分組處理,例如對不同的列進行agg聚合操作(mean,min,max等等),下面這篇文章主要給大家介紹了關于Pandas實現(xiàn)groupby分組統(tǒng)計方法的相關資料,需要的朋友可以參考下2023-06-06

