亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

Python自定義指標(biāo)聚類實(shí)例代碼

 更新時(shí)間:2022年02月28日 10:49:08   作者:荷碧·TZ  
K-means算法是最為經(jīng)典的基于劃分的聚類方法,是十大經(jīng)典數(shù)據(jù)挖掘算法之一,下面這篇文章主要給大家介紹了關(guān)于Python自定義指標(biāo)聚類的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下

前言

最近在研究 Yolov2 論文的時(shí)候,發(fā)現(xiàn)作者在做先驗(yàn)框聚類使用的指標(biāo)并非歐式距離,而是IOU。在找了很多資料之后,基本確定 Python 沒(méi)有自定義指標(biāo)聚類的函數(shù),所以打算自己做一個(gè)

設(shè)訓(xùn)練集的 shape 是 [n_sample, n_feature],基本思路是:

  • 簇中心初始化:第 1 個(gè)簇中心取樣本的特征均值,shape = [n_feature, ];從第 2 個(gè)簇中心開(kāi)始,用距離函數(shù) (自定義) 計(jì)算每個(gè)樣本到最近中心點(diǎn)的距離,歸一化后作為選取下一個(gè)簇中心的概率 —— 迭代到選取到足夠的簇中心為止
  • 簇中心調(diào)整:訓(xùn)練多輪,每一輪以樣本點(diǎn)到最近中心點(diǎn)的距離之和作為 loss,梯度下降法 + Adam 優(yōu)化器逼近最優(yōu)解,在 loss 浮動(dòng)值小于閾值的次數(shù)達(dá)到一定值時(shí)停止訓(xùn)練

因?yàn)樵O(shè)計(jì)之初就打算使用自定義距離函數(shù),所以求導(dǎo)是很大的難題。筆者不才,最終決定借助 PyTorch 自動(dòng)求導(dǎo)的天然優(yōu)勢(shì)

先給出歐式距離的計(jì)算函數(shù)

def Eu_dist(data, center):
    """ 以 歐氏距離 為聚類準(zhǔn)則的距離計(jì)算函數(shù)
        data: 形如 [n_sample, n_feature] 的 tensor
        center: 形如 [n_cluster, n_feature] 的 tensor"""
    data = data.unsqueeze(1)
    center = center.unsqueeze(0)
    dist = ((data - center) ** 2).sum(dim=2)
    return dist

然后就是聚類器的代碼:使用時(shí)只需關(guān)注 __init__、fit、classify 函數(shù)

import torch
import numpy as np
import matplotlib.pyplot as plt
Adam = torch.optim.Adam
 
def get_progress(current, target, bar_len=30):
    """ current: 當(dāng)前完成任務(wù)數(shù)
        target: 任務(wù)總數(shù)
        bar_len: 進(jìn)度條長(zhǎng)度
        return: 進(jìn)度條字符串"""
    assert current <= target
    percent = round(current / target * 100, 1)
    unit = 100 / bar_len
    solid = int(percent / unit)
    hollow = bar_len - solid
    return "■" * solid + "□" * hollow + f" {current}/{target}({percent}%)"
 
 
class Cluster:
    """ 聚類器
        n_cluster: 簇中心數(shù)
        dist_fun: 距離計(jì)算函數(shù)
            kwargs:
                data: 形如 [n_sample, n_feather] 的 tensor
                center: 形如 [n_cluster, n_feature] 的 tensor
            return: 形如 [n_sample, n_cluster] 的 tensor
        init: 初始簇中心
        max_iter: 最大迭代輪數(shù)
        lr: 中心點(diǎn)坐標(biāo)學(xué)習(xí)率
        stop_thresh: 停止訓(xùn)練的loss浮動(dòng)閾值
        cluster_centers_: 聚類中心
        labels_: 聚類結(jié)果"""
 
    def __init__(self, n_cluster, dist_fun, init=None, max_iter=300, lr=0.08, stop_thresh=1e-4):
        self._n_cluster = n_cluster
        self._dist_fun = dist_fun
        self._max_iter = max_iter
        self._lr = lr
        self._stop_thresh = stop_thresh
        # 初始化參數(shù)
        self.cluster_centers_ = None if init is None else torch.FloatTensor(init)
        self.labels_ = None
        self._bar_len = 20
 
    def fit(self, data):
        """ data: 形如 [n_sample, n_feature] 的 tensor
            return: loss浮動(dòng)日志"""
        if self.cluster_centers_ is None:
            self._init_cluster(data, self._max_iter // 5)
        log = self._train(data, self._max_iter, self._lr)
        # 開(kāi)始若干輪次的訓(xùn)練,得到loss浮動(dòng)日志
        return log
 
    def classify(self, data, show=False):
        """ data: 形如 [n_sample, n_feature] 的 tensor
            show: 繪制分類結(jié)果
            return: 分類標(biāo)簽"""
        dist = self._dist_fun(data, self.cluster_centers_)
        self.labels_ = dist.argmin(axis=1)
        # 將標(biāo)簽加載到實(shí)例屬性
        if show:
            for idx in range(self._n_cluster):
                container = data[self.labels_ == idx]
                plt.scatter(container[:, 0], container[:, 1], alpha=0.7)
            plt.scatter(self.cluster_centers_[:, 0], self.cluster_centers_[:, 1], c="gold", marker="p", s=50)
            plt.show()
        return self.labels_
 
    def _init_cluster(self, data, epochs):
        self.cluster_centers_ = data.mean(dim=0).reshape(1, -1)
        for idx in range(1, self._n_cluster):
            dist = np.array(self._dist_fun(data, self.cluster_centers_).min(dim=1)[0])
            new_cluster = data[np.random.choice(range(data.shape[0]), p=dist / dist.sum())].reshape(1, -1)
            # 取新的中心點(diǎn)
            self.cluster_centers_ = torch.cat([self.cluster_centers_, new_cluster], dim=0)
            progress = get_progress(idx, self._n_cluster, bar_len=self._n_cluster if self._n_cluster <= self._bar_len else self._bar_len)
            print(f"\rCluster Init: {progress}", end="")
            self._train(data, epochs, self._lr * 2.5, init=True)
            # 初始化簇中心時(shí)使用較大的lr
 
    def _train(self, data, epochs, lr, init=False):
        center = self.cluster_centers_.cuda()
        center.requires_grad = True
        data = data.cuda()
        optimizer = Adam([center], lr=lr)
        # 將中心數(shù)據(jù)加載到 GPU 上
        init_patience = int(epochs ** 0.5)
        patience = init_patience
        update_log = []
        min_loss = np.inf
        for epoch in range(epochs):
            # 對(duì)樣本分類并更新中心點(diǎn)
            sample_dist = self._dist_fun(data, center).min(dim=1)
            self.labels_ = sample_dist[1]
            loss = sum([sample_dist[0][self.labels_ == idx].mean() for idx in range(len(center))])
            # loss 函數(shù): 所有樣本到中心點(diǎn)的最小距離和 - 中心點(diǎn)間的最小間隔
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # 反向傳播梯度更新中心點(diǎn)
            loss = loss.item()
            progress = min_loss - loss
            update_log.append(progress)
            if progress > 0:
                self.cluster_centers_ = center.cpu().detach()
                min_loss = loss
                # 脫離計(jì)算圖后記錄中心點(diǎn)
            if progress < self._stop_thresh:
                patience -= 1
                # 耐心值減少
                if patience < 0:
                    break
                    # 耐心值歸零時(shí)退出
            else:
                patience = init_patience
                # 恢復(fù)耐心值
            progress = get_progress(init_patience - patience, init_patience, bar_len=self._bar_len)
            if not init:
                print(f"\rCluster: {progress}\titer: {epoch + 1}", end="")
        if not init:
            print("")
        return torch.FloatTensor(update_log)

與KMeans++比較

KMeans++ 是以歐式距離為聚類準(zhǔn)則的經(jīng)典聚類算法。在 iris 數(shù)據(jù)集上,KMeans++ 遠(yuǎn)遠(yuǎn)快于我的聚類器。但在我反復(fù)對(duì)比測(cè)試的幾輪里,我的聚類器精度也是不差的 —— 可以看到下圖里的聚類結(jié)果完全一致

 KMeans++My Cluster
Cost145 ms1597 ms
Center

[[5.9016, 2.7484, 4.3935, 1.4339],

[5.0060, 3.4280, 1.4620, 0.2460],
[6.8500, 3.0737, 5.7421, 2.0711]]

[[5.9016, 2.7485, 4.3934, 1.4338],
[5.0063, 3.4284, 1.4617, 0.2463],
[6.8500, 3.0741, 5.7420, 2.0714]]

雖然速度方面與老牌算法對(duì)比的確不行,但是我的這個(gè)聚類器最大的亮點(diǎn)還是自定義距離函數(shù)

Yolo 檢測(cè)框聚類

本來(lái)想用 Yolov4 檢測(cè)框聚類引入的 CIoU 做聚類,但是沒(méi)法解決梯度彌散的問(wèn)題,所以退其次用了 DIoU

def DIoU_dist(boxes, anchor):
    """ 以 DIoU 為聚類準(zhǔn)則的距離計(jì)算函數(shù)
        boxes: 形如 [n_sample, 2] 的 tensor
        anchor: 形如 [n_cluster, 2] 的 tensor"""
    n_sample = boxes.shape[0]
    n_cluster = anchor.shape[0]
    dist = Eu_dist(boxes, anchor)
    # 計(jì)算歐式距離
    union_inter = torch.prod(boxes, dim=1).reshape(-1, 1) + torch.prod(anchor, dim=1).reshape(1, -1)
    boxes = boxes.unsqueeze(1).repeat(1, n_cluster, 1)
    anchor = anchor.unsqueeze(0).repeat(n_sample, 1, 1)
    compare = torch.stack([boxes, anchor], dim=2)
    # 組合檢測(cè)框與 anchor 的信息
    diag = torch.sum(compare.max(dim=2)[0] ** 2, dim=2)
    dist /= diag
    # 計(jì)算外接矩形的對(duì)角線長(zhǎng)度
    inter = torch.prod(compare.min(dim=2)[0], dim=2)
    iou = inter / (union_inter - inter)
    # 計(jì)算 IoU
    dist += 1 - iou
    return dist

我提取了 DroneVehicle 數(shù)據(jù)集的 650156 個(gè)預(yù)測(cè)框的尺寸做聚類,在這個(gè)過(guò)程中發(fā)現(xiàn)因?yàn)樾〕叽绲念A(yù)測(cè)框過(guò)多,導(dǎo)致聚類中心聚集在原點(diǎn)附近。所以對(duì) loss 函數(shù)做了改進(jìn):先分類,再計(jì)算每個(gè)分類下的最大距離之和

橫軸表示檢測(cè)框的寬度,縱軸表示檢測(cè)框的高度,其數(shù)值都是相對(duì)于原圖尺寸的比例。若原圖尺寸為 608 * 608,則得到的 9 個(gè)先驗(yàn)框?yàn)椋?/p>

[ 2,  3 ][ 9,  13 ][ 19,  35 ]
[ 10,  76 ][ 60,  14 ][ 25,  134 ]
[ 167,  25 ][ 115,  54 ][ 70, 176 ]

總結(jié)

到此這篇關(guān)于Python自定義指標(biāo)聚類的文章就介紹到這了,更多相關(guān)Python自定義指標(biāo)聚類內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python 監(jiān)控logcat關(guān)鍵字功能

    python 監(jiān)控logcat關(guān)鍵字功能

    這篇文章主要介紹了python 監(jiān)控logcat關(guān)鍵字功能,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-09-09
  • python圖片處理庫(kù)Pillow實(shí)現(xiàn)簡(jiǎn)單PS功能

    python圖片處理庫(kù)Pillow實(shí)現(xiàn)簡(jiǎn)單PS功能

    Python 屆處理圖片最強(qiáng)的庫(kù)是 PIL(Python Image Library),但由于該庫(kù)只支持 2.x 版本,在此基礎(chǔ)上做了擴(kuò)展,出了一個(gè)兼容 3.x 的版本也就是 Pillow,因此,我們今天要用的庫(kù)就是Pillow
    2021-11-11
  • Python數(shù)據(jù)類型詳解(一)字符串

    Python數(shù)據(jù)類型詳解(一)字符串

    簡(jiǎn)單的說(shuō)字符串String就是使用引號(hào)定義的一組可以包含數(shù)字,字母,符號(hào)(非特殊系統(tǒng)符號(hào))的集合。今天我們就來(lái)詳細(xì)探討下Python數(shù)據(jù)類型中的字符串
    2016-05-05
  • 通過(guò)實(shí)例解析python subprocess模塊原理及用法

    通過(guò)實(shí)例解析python subprocess模塊原理及用法

    這篇文章主要介紹了通過(guò)實(shí)例解析python subprocess模塊原理及用法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-10-10
  • Python使用爬蟲(chóng)爬取靜態(tài)網(wǎng)頁(yè)圖片的方法詳解

    Python使用爬蟲(chóng)爬取靜態(tài)網(wǎng)頁(yè)圖片的方法詳解

    這篇文章主要介紹了Python使用爬蟲(chóng)爬取靜態(tài)網(wǎng)頁(yè)圖片的方法,較為詳細(xì)的說(shuō)明了爬蟲(chóng)的原理,并結(jié)合實(shí)例形式分析了Python使用爬蟲(chóng)來(lái)爬取靜態(tài)網(wǎng)頁(yè)圖片的相關(guān)操作技巧,需要的朋友可以參考下
    2018-06-06
  • 最新評(píng)論