Python自定義指標(biāo)聚類實(shí)例代碼
前言
最近在研究 Yolov2 論文的時(shí)候,發(fā)現(xiàn)作者在做先驗(yàn)框聚類使用的指標(biāo)并非歐式距離,而是IOU。在找了很多資料之后,基本確定 Python 沒(méi)有自定義指標(biāo)聚類的函數(shù),所以打算自己做一個(gè)
設(shè)訓(xùn)練集的 shape 是 [n_sample, n_feature],基本思路是:
- 簇中心初始化:第 1 個(gè)簇中心取樣本的特征均值,shape = [n_feature, ];從第 2 個(gè)簇中心開(kāi)始,用距離函數(shù) (自定義) 計(jì)算每個(gè)樣本到最近中心點(diǎn)的距離,歸一化后作為選取下一個(gè)簇中心的概率 —— 迭代到選取到足夠的簇中心為止
- 簇中心調(diào)整:訓(xùn)練多輪,每一輪以樣本點(diǎn)到最近中心點(diǎn)的距離之和作為 loss,梯度下降法 + Adam 優(yōu)化器逼近最優(yōu)解,在 loss 浮動(dòng)值小于閾值的次數(shù)達(dá)到一定值時(shí)停止訓(xùn)練
因?yàn)樵O(shè)計(jì)之初就打算使用自定義距離函數(shù),所以求導(dǎo)是很大的難題。筆者不才,最終決定借助 PyTorch 自動(dòng)求導(dǎo)的天然優(yōu)勢(shì)
先給出歐式距離的計(jì)算函數(shù)
def Eu_dist(data, center): """ 以 歐氏距離 為聚類準(zhǔn)則的距離計(jì)算函數(shù) data: 形如 [n_sample, n_feature] 的 tensor center: 形如 [n_cluster, n_feature] 的 tensor""" data = data.unsqueeze(1) center = center.unsqueeze(0) dist = ((data - center) ** 2).sum(dim=2) return dist
然后就是聚類器的代碼:使用時(shí)只需關(guān)注 __init__、fit、classify 函數(shù)
import torch import numpy as np import matplotlib.pyplot as plt Adam = torch.optim.Adam def get_progress(current, target, bar_len=30): """ current: 當(dāng)前完成任務(wù)數(shù) target: 任務(wù)總數(shù) bar_len: 進(jìn)度條長(zhǎng)度 return: 進(jìn)度條字符串""" assert current <= target percent = round(current / target * 100, 1) unit = 100 / bar_len solid = int(percent / unit) hollow = bar_len - solid return "■" * solid + "□" * hollow + f" {current}/{target}({percent}%)" class Cluster: """ 聚類器 n_cluster: 簇中心數(shù) dist_fun: 距離計(jì)算函數(shù) kwargs: data: 形如 [n_sample, n_feather] 的 tensor center: 形如 [n_cluster, n_feature] 的 tensor return: 形如 [n_sample, n_cluster] 的 tensor init: 初始簇中心 max_iter: 最大迭代輪數(shù) lr: 中心點(diǎn)坐標(biāo)學(xué)習(xí)率 stop_thresh: 停止訓(xùn)練的loss浮動(dòng)閾值 cluster_centers_: 聚類中心 labels_: 聚類結(jié)果""" def __init__(self, n_cluster, dist_fun, init=None, max_iter=300, lr=0.08, stop_thresh=1e-4): self._n_cluster = n_cluster self._dist_fun = dist_fun self._max_iter = max_iter self._lr = lr self._stop_thresh = stop_thresh # 初始化參數(shù) self.cluster_centers_ = None if init is None else torch.FloatTensor(init) self.labels_ = None self._bar_len = 20 def fit(self, data): """ data: 形如 [n_sample, n_feature] 的 tensor return: loss浮動(dòng)日志""" if self.cluster_centers_ is None: self._init_cluster(data, self._max_iter // 5) log = self._train(data, self._max_iter, self._lr) # 開(kāi)始若干輪次的訓(xùn)練,得到loss浮動(dòng)日志 return log def classify(self, data, show=False): """ data: 形如 [n_sample, n_feature] 的 tensor show: 繪制分類結(jié)果 return: 分類標(biāo)簽""" dist = self._dist_fun(data, self.cluster_centers_) self.labels_ = dist.argmin(axis=1) # 將標(biāo)簽加載到實(shí)例屬性 if show: for idx in range(self._n_cluster): container = data[self.labels_ == idx] plt.scatter(container[:, 0], container[:, 1], alpha=0.7) plt.scatter(self.cluster_centers_[:, 0], self.cluster_centers_[:, 1], c="gold", marker="p", s=50) plt.show() return self.labels_ def _init_cluster(self, data, epochs): self.cluster_centers_ = data.mean(dim=0).reshape(1, -1) for idx in range(1, self._n_cluster): dist = np.array(self._dist_fun(data, self.cluster_centers_).min(dim=1)[0]) new_cluster = data[np.random.choice(range(data.shape[0]), p=dist / dist.sum())].reshape(1, -1) # 取新的中心點(diǎn) self.cluster_centers_ = torch.cat([self.cluster_centers_, new_cluster], dim=0) progress = get_progress(idx, self._n_cluster, bar_len=self._n_cluster if self._n_cluster <= self._bar_len else self._bar_len) print(f"\rCluster Init: {progress}", end="") self._train(data, epochs, self._lr * 2.5, init=True) # 初始化簇中心時(shí)使用較大的lr def _train(self, data, epochs, lr, init=False): center = self.cluster_centers_.cuda() center.requires_grad = True data = data.cuda() optimizer = Adam([center], lr=lr) # 將中心數(shù)據(jù)加載到 GPU 上 init_patience = int(epochs ** 0.5) patience = init_patience update_log = [] min_loss = np.inf for epoch in range(epochs): # 對(duì)樣本分類并更新中心點(diǎn) sample_dist = self._dist_fun(data, center).min(dim=1) self.labels_ = sample_dist[1] loss = sum([sample_dist[0][self.labels_ == idx].mean() for idx in range(len(center))]) # loss 函數(shù): 所有樣本到中心點(diǎn)的最小距離和 - 中心點(diǎn)間的最小間隔 loss.backward() optimizer.step() optimizer.zero_grad() # 反向傳播梯度更新中心點(diǎn) loss = loss.item() progress = min_loss - loss update_log.append(progress) if progress > 0: self.cluster_centers_ = center.cpu().detach() min_loss = loss # 脫離計(jì)算圖后記錄中心點(diǎn) if progress < self._stop_thresh: patience -= 1 # 耐心值減少 if patience < 0: break # 耐心值歸零時(shí)退出 else: patience = init_patience # 恢復(fù)耐心值 progress = get_progress(init_patience - patience, init_patience, bar_len=self._bar_len) if not init: print(f"\rCluster: {progress}\titer: {epoch + 1}", end="") if not init: print("") return torch.FloatTensor(update_log)
與KMeans++比較
KMeans++ 是以歐式距離為聚類準(zhǔn)則的經(jīng)典聚類算法。在 iris 數(shù)據(jù)集上,KMeans++ 遠(yuǎn)遠(yuǎn)快于我的聚類器。但在我反復(fù)對(duì)比測(cè)試的幾輪里,我的聚類器精度也是不差的 —— 可以看到下圖里的聚類結(jié)果完全一致
KMeans++ | My Cluster | |
Cost | 145 ms | 1597 ms |
Center | [[5.9016, 2.7484, 4.3935, 1.4339], [5.0060, 3.4280, 1.4620, 0.2460], | [[5.9016, 2.7485, 4.3934, 1.4338], |
雖然速度方面與老牌算法對(duì)比的確不行,但是我的這個(gè)聚類器最大的亮點(diǎn)還是自定義距離函數(shù)
Yolo 檢測(cè)框聚類
本來(lái)想用 Yolov4 檢測(cè)框聚類引入的 CIoU 做聚類,但是沒(méi)法解決梯度彌散的問(wèn)題,所以退其次用了 DIoU
def DIoU_dist(boxes, anchor): """ 以 DIoU 為聚類準(zhǔn)則的距離計(jì)算函數(shù) boxes: 形如 [n_sample, 2] 的 tensor anchor: 形如 [n_cluster, 2] 的 tensor""" n_sample = boxes.shape[0] n_cluster = anchor.shape[0] dist = Eu_dist(boxes, anchor) # 計(jì)算歐式距離 union_inter = torch.prod(boxes, dim=1).reshape(-1, 1) + torch.prod(anchor, dim=1).reshape(1, -1) boxes = boxes.unsqueeze(1).repeat(1, n_cluster, 1) anchor = anchor.unsqueeze(0).repeat(n_sample, 1, 1) compare = torch.stack([boxes, anchor], dim=2) # 組合檢測(cè)框與 anchor 的信息 diag = torch.sum(compare.max(dim=2)[0] ** 2, dim=2) dist /= diag # 計(jì)算外接矩形的對(duì)角線長(zhǎng)度 inter = torch.prod(compare.min(dim=2)[0], dim=2) iou = inter / (union_inter - inter) # 計(jì)算 IoU dist += 1 - iou return dist
我提取了 DroneVehicle 數(shù)據(jù)集的 650156 個(gè)預(yù)測(cè)框的尺寸做聚類,在這個(gè)過(guò)程中發(fā)現(xiàn)因?yàn)樾〕叽绲念A(yù)測(cè)框過(guò)多,導(dǎo)致聚類中心聚集在原點(diǎn)附近。所以對(duì) loss 函數(shù)做了改進(jìn):先分類,再計(jì)算每個(gè)分類下的最大距離之和
橫軸表示檢測(cè)框的寬度,縱軸表示檢測(cè)框的高度,其數(shù)值都是相對(duì)于原圖尺寸的比例。若原圖尺寸為 608 * 608,則得到的 9 個(gè)先驗(yàn)框?yàn)椋?/p>
[ 2, 3 ] | [ 9, 13 ] | [ 19, 35 ] |
[ 10, 76 ] | [ 60, 14 ] | [ 25, 134 ] |
[ 167, 25 ] | [ 115, 54 ] | [ 70, 176 ] |
總結(jié)
到此這篇關(guān)于Python自定義指標(biāo)聚類的文章就介紹到這了,更多相關(guān)Python自定義指標(biāo)聚類內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python面經(jīng)之16個(gè)高頻面試問(wèn)題總結(jié)
這篇文章主要給大家介紹了關(guān)于Python面經(jīng)之16個(gè)高頻面試問(wèn)題的相關(guān)資料,幫助大家回顧基礎(chǔ)知識(shí),了解面試套路,對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2022-03-03關(guān)于Pytorch的MNIST數(shù)據(jù)集的預(yù)處理詳解
今天小編就為大家分享一篇關(guān)于Pytorch的MNIST數(shù)據(jù)集的預(yù)處理詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-01-01Python中tkinter+MySQL實(shí)現(xiàn)增刪改查
這篇文章主要介紹了Python中tkinter+MySQL實(shí)現(xiàn)增刪改查,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04

python 監(jiān)控logcat關(guān)鍵字功能

python圖片處理庫(kù)Pillow實(shí)現(xiàn)簡(jiǎn)單PS功能

通過(guò)實(shí)例解析python subprocess模塊原理及用法

Python使用爬蟲(chóng)爬取靜態(tài)網(wǎng)頁(yè)圖片的方法詳解