Python自定義指標(biāo)聚類實(shí)例代碼
前言
最近在研究 Yolov2 論文的時(shí)候,發(fā)現(xiàn)作者在做先驗(yàn)框聚類使用的指標(biāo)并非歐式距離,而是IOU。在找了很多資料之后,基本確定 Python 沒(méi)有自定義指標(biāo)聚類的函數(shù),所以打算自己做一個(gè)
設(shè)訓(xùn)練集的 shape 是 [n_sample, n_feature],基本思路是:
- 簇中心初始化:第 1 個(gè)簇中心取樣本的特征均值,shape = [n_feature, ];從第 2 個(gè)簇中心開(kāi)始,用距離函數(shù) (自定義) 計(jì)算每個(gè)樣本到最近中心點(diǎn)的距離,歸一化后作為選取下一個(gè)簇中心的概率 —— 迭代到選取到足夠的簇中心為止
- 簇中心調(diào)整:訓(xùn)練多輪,每一輪以樣本點(diǎn)到最近中心點(diǎn)的距離之和作為 loss,梯度下降法 + Adam 優(yōu)化器逼近最優(yōu)解,在 loss 浮動(dòng)值小于閾值的次數(shù)達(dá)到一定值時(shí)停止訓(xùn)練
因?yàn)樵O(shè)計(jì)之初就打算使用自定義距離函數(shù),所以求導(dǎo)是很大的難題。筆者不才,最終決定借助 PyTorch 自動(dòng)求導(dǎo)的天然優(yōu)勢(shì)
先給出歐式距離的計(jì)算函數(shù)
def Eu_dist(data, center):
""" 以 歐氏距離 為聚類準(zhǔn)則的距離計(jì)算函數(shù)
data: 形如 [n_sample, n_feature] 的 tensor
center: 形如 [n_cluster, n_feature] 的 tensor"""
data = data.unsqueeze(1)
center = center.unsqueeze(0)
dist = ((data - center) ** 2).sum(dim=2)
return dist然后就是聚類器的代碼:使用時(shí)只需關(guān)注 __init__、fit、classify 函數(shù)
import torch
import numpy as np
import matplotlib.pyplot as plt
Adam = torch.optim.Adam
def get_progress(current, target, bar_len=30):
""" current: 當(dāng)前完成任務(wù)數(shù)
target: 任務(wù)總數(shù)
bar_len: 進(jìn)度條長(zhǎng)度
return: 進(jìn)度條字符串"""
assert current <= target
percent = round(current / target * 100, 1)
unit = 100 / bar_len
solid = int(percent / unit)
hollow = bar_len - solid
return "■" * solid + "□" * hollow + f" {current}/{target}({percent}%)"
class Cluster:
""" 聚類器
n_cluster: 簇中心數(shù)
dist_fun: 距離計(jì)算函數(shù)
kwargs:
data: 形如 [n_sample, n_feather] 的 tensor
center: 形如 [n_cluster, n_feature] 的 tensor
return: 形如 [n_sample, n_cluster] 的 tensor
init: 初始簇中心
max_iter: 最大迭代輪數(shù)
lr: 中心點(diǎn)坐標(biāo)學(xué)習(xí)率
stop_thresh: 停止訓(xùn)練的loss浮動(dòng)閾值
cluster_centers_: 聚類中心
labels_: 聚類結(jié)果"""
def __init__(self, n_cluster, dist_fun, init=None, max_iter=300, lr=0.08, stop_thresh=1e-4):
self._n_cluster = n_cluster
self._dist_fun = dist_fun
self._max_iter = max_iter
self._lr = lr
self._stop_thresh = stop_thresh
# 初始化參數(shù)
self.cluster_centers_ = None if init is None else torch.FloatTensor(init)
self.labels_ = None
self._bar_len = 20
def fit(self, data):
""" data: 形如 [n_sample, n_feature] 的 tensor
return: loss浮動(dòng)日志"""
if self.cluster_centers_ is None:
self._init_cluster(data, self._max_iter // 5)
log = self._train(data, self._max_iter, self._lr)
# 開(kāi)始若干輪次的訓(xùn)練,得到loss浮動(dòng)日志
return log
def classify(self, data, show=False):
""" data: 形如 [n_sample, n_feature] 的 tensor
show: 繪制分類結(jié)果
return: 分類標(biāo)簽"""
dist = self._dist_fun(data, self.cluster_centers_)
self.labels_ = dist.argmin(axis=1)
# 將標(biāo)簽加載到實(shí)例屬性
if show:
for idx in range(self._n_cluster):
container = data[self.labels_ == idx]
plt.scatter(container[:, 0], container[:, 1], alpha=0.7)
plt.scatter(self.cluster_centers_[:, 0], self.cluster_centers_[:, 1], c="gold", marker="p", s=50)
plt.show()
return self.labels_
def _init_cluster(self, data, epochs):
self.cluster_centers_ = data.mean(dim=0).reshape(1, -1)
for idx in range(1, self._n_cluster):
dist = np.array(self._dist_fun(data, self.cluster_centers_).min(dim=1)[0])
new_cluster = data[np.random.choice(range(data.shape[0]), p=dist / dist.sum())].reshape(1, -1)
# 取新的中心點(diǎn)
self.cluster_centers_ = torch.cat([self.cluster_centers_, new_cluster], dim=0)
progress = get_progress(idx, self._n_cluster, bar_len=self._n_cluster if self._n_cluster <= self._bar_len else self._bar_len)
print(f"\rCluster Init: {progress}", end="")
self._train(data, epochs, self._lr * 2.5, init=True)
# 初始化簇中心時(shí)使用較大的lr
def _train(self, data, epochs, lr, init=False):
center = self.cluster_centers_.cuda()
center.requires_grad = True
data = data.cuda()
optimizer = Adam([center], lr=lr)
# 將中心數(shù)據(jù)加載到 GPU 上
init_patience = int(epochs ** 0.5)
patience = init_patience
update_log = []
min_loss = np.inf
for epoch in range(epochs):
# 對(duì)樣本分類并更新中心點(diǎn)
sample_dist = self._dist_fun(data, center).min(dim=1)
self.labels_ = sample_dist[1]
loss = sum([sample_dist[0][self.labels_ == idx].mean() for idx in range(len(center))])
# loss 函數(shù): 所有樣本到中心點(diǎn)的最小距離和 - 中心點(diǎn)間的最小間隔
loss.backward()
optimizer.step()
optimizer.zero_grad()
# 反向傳播梯度更新中心點(diǎn)
loss = loss.item()
progress = min_loss - loss
update_log.append(progress)
if progress > 0:
self.cluster_centers_ = center.cpu().detach()
min_loss = loss
# 脫離計(jì)算圖后記錄中心點(diǎn)
if progress < self._stop_thresh:
patience -= 1
# 耐心值減少
if patience < 0:
break
# 耐心值歸零時(shí)退出
else:
patience = init_patience
# 恢復(fù)耐心值
progress = get_progress(init_patience - patience, init_patience, bar_len=self._bar_len)
if not init:
print(f"\rCluster: {progress}\titer: {epoch + 1}", end="")
if not init:
print("")
return torch.FloatTensor(update_log)與KMeans++比較
KMeans++ 是以歐式距離為聚類準(zhǔn)則的經(jīng)典聚類算法。在 iris 數(shù)據(jù)集上,KMeans++ 遠(yuǎn)遠(yuǎn)快于我的聚類器。但在我反復(fù)對(duì)比測(cè)試的幾輪里,我的聚類器精度也是不差的 —— 可以看到下圖里的聚類結(jié)果完全一致

| KMeans++ | My Cluster | |
| Cost | 145 ms | 1597 ms |
| Center | [[5.9016, 2.7484, 4.3935, 1.4339], [5.0060, 3.4280, 1.4620, 0.2460], | [[5.9016, 2.7485, 4.3934, 1.4338], |
雖然速度方面與老牌算法對(duì)比的確不行,但是我的這個(gè)聚類器最大的亮點(diǎn)還是自定義距離函數(shù)
Yolo 檢測(cè)框聚類
本來(lái)想用 Yolov4 檢測(cè)框聚類引入的 CIoU 做聚類,但是沒(méi)法解決梯度彌散的問(wèn)題,所以退其次用了 DIoU
def DIoU_dist(boxes, anchor):
""" 以 DIoU 為聚類準(zhǔn)則的距離計(jì)算函數(shù)
boxes: 形如 [n_sample, 2] 的 tensor
anchor: 形如 [n_cluster, 2] 的 tensor"""
n_sample = boxes.shape[0]
n_cluster = anchor.shape[0]
dist = Eu_dist(boxes, anchor)
# 計(jì)算歐式距離
union_inter = torch.prod(boxes, dim=1).reshape(-1, 1) + torch.prod(anchor, dim=1).reshape(1, -1)
boxes = boxes.unsqueeze(1).repeat(1, n_cluster, 1)
anchor = anchor.unsqueeze(0).repeat(n_sample, 1, 1)
compare = torch.stack([boxes, anchor], dim=2)
# 組合檢測(cè)框與 anchor 的信息
diag = torch.sum(compare.max(dim=2)[0] ** 2, dim=2)
dist /= diag
# 計(jì)算外接矩形的對(duì)角線長(zhǎng)度
inter = torch.prod(compare.min(dim=2)[0], dim=2)
iou = inter / (union_inter - inter)
# 計(jì)算 IoU
dist += 1 - iou
return dist我提取了 DroneVehicle 數(shù)據(jù)集的 650156 個(gè)預(yù)測(cè)框的尺寸做聚類,在這個(gè)過(guò)程中發(fā)現(xiàn)因?yàn)樾〕叽绲念A(yù)測(cè)框過(guò)多,導(dǎo)致聚類中心聚集在原點(diǎn)附近。所以對(duì) loss 函數(shù)做了改進(jìn):先分類,再計(jì)算每個(gè)分類下的最大距離之和

橫軸表示檢測(cè)框的寬度,縱軸表示檢測(cè)框的高度,其數(shù)值都是相對(duì)于原圖尺寸的比例。若原圖尺寸為 608 * 608,則得到的 9 個(gè)先驗(yàn)框?yàn)椋?/p>
| [ 2, 3 ] | [ 9, 13 ] | [ 19, 35 ] |
| [ 10, 76 ] | [ 60, 14 ] | [ 25, 134 ] |
| [ 167, 25 ] | [ 115, 54 ] | [ 70, 176 ] |
總結(jié)
到此這篇關(guān)于Python自定義指標(biāo)聚類的文章就介紹到這了,更多相關(guān)Python自定義指標(biāo)聚類內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python面經(jīng)之16個(gè)高頻面試問(wèn)題總結(jié)
這篇文章主要給大家介紹了關(guān)于Python面經(jīng)之16個(gè)高頻面試問(wèn)題的相關(guān)資料,幫助大家回顧基礎(chǔ)知識(shí),了解面試套路,對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2022-03-03
關(guān)于Pytorch的MNIST數(shù)據(jù)集的預(yù)處理詳解
今天小編就為大家分享一篇關(guān)于Pytorch的MNIST數(shù)據(jù)集的預(yù)處理詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-01-01
Python中tkinter+MySQL實(shí)現(xiàn)增刪改查
這篇文章主要介紹了Python中tkinter+MySQL實(shí)現(xiàn)增刪改查,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04
python 監(jiān)控logcat關(guān)鍵字功能
python圖片處理庫(kù)Pillow實(shí)現(xiàn)簡(jiǎn)單PS功能
通過(guò)實(shí)例解析python subprocess模塊原理及用法
Python使用爬蟲(chóng)爬取靜態(tài)網(wǎng)頁(yè)圖片的方法詳解

