opencv3機器學習之EM算法示例詳解
引言
不同于其它的機器學習模型,EM算法是一種非監(jiān)督的學習算法,它的輸入數(shù)據(jù)事先不需要進行標注。相反,該算法從給定的樣本集中,能計算出高斯混和參數(shù)的最大似然估計。也能得到每個樣本對應的標注值,類似于kmeans聚類(輸入樣本數(shù)據(jù),輸出樣本數(shù)據(jù)的標注)。實際上,高斯混和模型GMM和kmeans都是EM算法的應用。
在opencv3.0中,EM算法的函數(shù)是trainEM,函數(shù)原型為:
bool trainEM(InputArray samples, OutputArray logLikelihoods=noArray(),OutputArray labels=noArray(),OutputArray probs=noArray())
四個參數(shù):
samples: 輸入的樣本,一個單通道的矩陣。從這個樣本中,進行高斯混和模型估計。
logLikelihoods: 可選項,輸出一個矩陣,里面包含每個樣本的似然對數(shù)值。
labels: 可選項,輸出每個樣本對應的標注。
probs: 可選項,輸出一個矩陣,里面包含每個隱性變量的后驗概率
這個函數(shù)沒有輸入?yún)?shù)的初始化值,是因為它會自動執(zhí)行kmeans算法,將kmeans算法得到的結果作為參數(shù)初始化。
這個trainEM函數(shù)實際把E步驟和M步驟都包含進去了,我們也可以對兩個步驟分開執(zhí)行,OPENCV3.0中也提供了分別執(zhí)行的函數(shù):
bool trainE(InputArray samples, InputArray means0,
InputArray covs0=noArray(),
InputArray weights0=noArray(),
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray())bool trainM(InputArray samples, InputArray probs0,
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray())trainEM函數(shù)的功能和kmeans差不多,都是實現(xiàn)自動聚類,輸出每個樣本對應的標注值。但它比kmeans還多出一個功能,就是它還能起到訓練分類器的作用,用于后續(xù)新樣本的預測。
預測函數(shù)原型為:
Vec2d predict2(InputArray sample, OutputArray probs) const
sample: 待測樣本
probs : 和上面一樣,一個可選的輸出值,包含每個隱性變量的后驗概率
返回一個Vec2d類型的數(shù),包括兩個元素的double向量,第一個元素為樣本的似然對數(shù)值,第二個元素為最大可能混和分量的索引值。
在本文中,我們用兩個實例來學習opencv中的EM算法的應用。
一、opencv3.0中自帶的例子
既包括聚類trianEM,也包括預測predict2
代碼:
#include "stdafx.h"
#include "opencv2/opencv.hpp"
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;
//使用EM算法實現(xiàn)樣本的聚類及預測
int main()
{
const int N = 4; //分成4類
const int N1 = (int)sqrt((double)N);
//定義四種顏色,每一類用一種顏色表示
const Scalar colors[] =
{
Scalar(0, 0, 255), Scalar(0, 255, 0),
Scalar(0, 255, 255), Scalar(255, 255, 0)
};
int i, j;
int nsamples = 100; //100個樣本點
Mat samples(nsamples, 2, CV_32FC1); //樣本矩陣,100行2列,即100個坐標點
Mat img = Mat::zeros(Size(500, 500), CV_8UC3); //待測數(shù)據(jù),每一個坐標點為一個待測數(shù)據(jù)
samples = samples.reshape(2, 0);
//循環(huán)生成四個類別樣本數(shù)據(jù),共樣本100個,每類樣本25個
for (i = 0; i < N; i++)
{
Mat samples_part = samples.rowRange(i*nsamples / N, (i + 1)*nsamples / N);
//設置均值
Scalar mean(((i%N1) + 1)*img.rows / (N1 + 1),
((i / N1) + 1)*img.rows / (N1 + 1));
//設置標準差
Scalar sigma(30, 30);
randn(samples_part, mean, sigma); //根據(jù)均值和標準差,隨機生成25個正態(tài)分布坐標點作為樣本
}
samples = samples.reshape(1, 0);
// 訓練分類器
Mat labels; //標注,不需要事先知道
Ptr<EM> em_model = EM::create();
em_model->setClustersNumber(N);
em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
em_model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 300, 0.1));
em_model->trainEM(samples, noArray(), labels, noArray());
//對每個坐標點進行分類,并根據(jù)類別用不同的顏色畫出
Mat sample(1, 2, CV_32FC1);
for (i = 0; i < img.rows; i++)
{
for (j = 0; j < img.cols; j++)
{
sample.at<float>(0) = (float)j;
sample.at<float>(1) = (float)i;
//predict2返回的是double值,用cvRound進行四舍五入得到整型
//此處返回的是兩個值Vec2d,取第二個值作為樣本標注
int response = cvRound(em_model->predict2(sample, noArray())[1]);
Scalar c = colors[response]; //為不同類別設定顏色
circle(img, Point(j, i), 1, c*0.75, FILLED);
}
}
//畫出樣本點
for (i = 0; i < nsamples; i++)
{
Point pt(cvRound(samples.at<float>(i, 0)), cvRound(samples.at<float>(i, 1)));
circle(img, pt, 2, colors[labels.at<int>(i)], FILLED);
}
imshow("EM聚類結果", img);
waitKey(0);
return 0;
}結果:

二、trainEM實現(xiàn)自動聚類進行圖片目標檢測
只用trainEM實現(xiàn)自動聚類功能,進行圖片中的目標檢測
代碼:
#include "stdafx.h"
#include "opencv2/opencv.hpp"
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;
int main()
{
const int MAX_CLUSTERS = 5;
Vec3b colorTab[] =
{
Vec3b(0, 0, 255),
Vec3b(0, 255, 0),
Vec3b(255, 100, 100),
Vec3b(255, 0, 255),
Vec3b(0, 255, 255)
};
Mat data, labels;
Mat pic = imread("d:/woman.png");
for (int i = 0; i < pic.rows; i++)
for (int j = 0; j < pic.cols; j++)
{
Vec3b point = pic.at<Vec3b>(i, j);
Mat tmp = (Mat_<float>(1, 3) << point[0], point[1], point[2]);
data.push_back(tmp);
}
int N =3; //聚成3類
Ptr<EM> em_model = EM::create();
em_model->setClustersNumber(N);
em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
em_model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 300, 0.1));
em_model->trainEM(data, noArray(), labels, noArray());
int n = 0;
//顯示聚類結果,不同的類別用不同的顏色顯示
for (int i = 0; i < pic.rows; i++)
for (int j = 0; j < pic.cols; j++)
{
int clusterIdx = labels.at<int>(n);
pic.at<Vec3b>(i, j) = colorTab[clusterIdx];
n++;
}
imshow("pic", pic);
waitKey(0);
return 0;
}測試圖片

測試結果:

以上就是opencv3機器學習之EM算法的詳細內(nèi)容,更多關于opencv3 EM算法的資料請關注腳本之家其它相關文章!
相關文章
Visual Studio 2019 DLL動態(tài)庫連接實例(圖文教程)
這篇文章主要介紹了Visual Studio 2019 DLL動態(tài)庫連接實例,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-03-03
VC++植物大戰(zhàn)僵尸中文版修改器實現(xiàn)代碼
這篇文章主要介紹了VC++植物大戰(zhàn)僵尸中文版修改器實現(xiàn)代碼,可實現(xiàn)植物大戰(zhàn)僵尸中的無限陽光與無冷卻時間功能,需要的朋友可以參考下2015-04-04
Linux C/C++實現(xiàn)DNS客戶端請求域名IP的示例代碼
DNS全稱:Domain Name System,域名解析系統(tǒng),是互聯(lián)網(wǎng)的一項服務,本文主要介紹了C/C++如何實現(xiàn)DNS客戶端請求域名IP,感興趣的可以了解下2024-03-03

