python算法學(xué)習(xí)雙曲嵌入論文方法與代碼解析說(shuō)明
本篇接上一篇:python算法學(xué)習(xí)雙曲嵌入論文代碼實(shí)現(xiàn)數(shù)據(jù)集介紹
1. 方法說(shuō)明
首先學(xué)習(xí)相關(guān)的論文中的一些知識(shí),并結(jié)合進(jìn)行代碼的編寫(xiě)。文中主要使用Poincaré embedding。
對(duì)應(yīng)的python代碼為:
def dist1(vec1, vec2): # eqn1 diff_vec = vec1 - vec2 return 1 + 2 * norm(diff_vec) / ((1 - norm(vec1)) * (1 - norm(vec2)))
損失函數(shù)
我們想要尋找最優(yōu)的embedding,就需要構(gòu)建一個(gè)損失函數(shù),目標(biāo)是使得相似詞匯的embedding結(jié)果,盡可能接近,且層級(jí)越高(類別越大)的詞越靠近中心。我們需要最小化這個(gè)損失函數(shù),從而得到embedding的結(jié)果。
其實(shí)在傳統(tǒng)的詞嵌入中,我們也是用上述的損失函數(shù),但距離選用的是余弦距離。
梯度下降
后面將使用梯度下降方法進(jìn)行求解迭代。
由于是將歐氏空間計(jì)算得到的梯度在黎曼空間中進(jìn)行迭代,由上文的(1)式,我們有:
梯度求解
202111595310129
對(duì)應(yīng)的更新函數(shù)在Python中設(shè)置如下:
# 范數(shù)計(jì)算 def norm(x): return np.dot(x, x) # 距離函數(shù)對(duì)\theta求偏導(dǎo) def compute_distance_gradients(theta, x, gamma): alpha = (1.0 - np.dot(theta, theta)) norm_x = norm(x) beta = (1 - norm_x) c_ = 4.0 / (alpha * beta * sqrt(gamma ** 2 - 1)) return c_ * ((norm_x - 2 * np.dot(theta, x) + 1) / alpha * theta - x) # 更新公式 def update(emb, grad, lr): c_ = (1 - norm(emb)) ** 2 / 4 upd = lr * c_ * grad emb = emb - upd if (norm(emb) >= 1): emb = emb / sqrt(norm(emb)) - eps return emb
至此,我們就可以開(kāi)始寫(xiě)一個(gè)完整的訓(xùn)練過(guò)程了。在這之前,再補(bǔ)充一個(gè)繪圖函數(shù)(可以看embedding的實(shí)際訓(xùn)練情況):
def plotall(ii): fig = plt.figure(figsize=(10, 10)) # 繪制所有節(jié)點(diǎn) for a in emb: plt.plot(emb[a][0], emb[a][1], marker = 'o', color = [levelOfNode[a]/(last_level+1),levelOfNode[a]/(last_level+1),levelOfNode[a]/(last_level+1)]) for a in network: for b in network[a]: plt.plot([emb[a][0], emb[b][0]], [emb[a][1], emb[b][1]], color = [levelOfNode[a]/(last_level+1),levelOfNode[a]/(last_level+1),levelOfNode[a]/(last_level+1)]) circle = plt.Circle((0, 0), 1, color='y', fill=False) plt.gcf().gca().add_artist(circle) plt.xlim(-1, 1) plt.ylim(-1, 1) fig.savefig('~/GitHub/hyperE/fig/' + str(last_level) + '_' + str(ii) + '.png', dpi = 200)
2. 代碼訓(xùn)練過(guò)程
首先初始化embeddings,這里按照論文中寫(xiě)的,用 ( − 0.001 , 0.001 ) (-0.001, 0.001) (−0.001,0.001)間的均勻分布進(jìn)行隨機(jī)初始化:
emb = {} for node in levelOfNode: emb[node] = np.random.uniform(low = -0.001, high = 0.001, size = (2, ))
下面設(shè)置學(xué)習(xí)率等參數(shù):
vocab = list(emb.keys()) eps = 1e-5 lr = 0.1 # 學(xué)習(xí)率 num_negs = 10 # 負(fù)樣本個(gè)數(shù)
接下來(lái)開(kāi)始正式迭代,具體每一行的含義均在注釋中有進(jìn)行說(shuō)明:
# 繪制初始化權(quán)重 plotall("init") for epoch in range(1000): loss = [] random.shuffle(vocab) # 下面需要抽取不同的樣本:pos2 與 pos1 相關(guān);negs 不與 pos1 相關(guān) for pos1 in vocab: if not network[pos1]: # 葉子節(jié)點(diǎn)則不進(jìn)行訓(xùn)練 continue pos2 = random.choice(network[pos1]) # 隨機(jī)選取與pos1相關(guān)的節(jié)點(diǎn)pos2 dist_pos_ = dist1(emb[pos1], emb[pos2]) # 保留中間變量gamma,加速計(jì)算 dist_pos = np.arccosh(dist_pos_) # 計(jì)算pos1與pos2之間的距離 # 下面抽取負(fù)樣本組(不與pos1相關(guān)的樣本組) negs = [[pos1, pos1]] dist_negs_ = [1] dist_negs = [0] while (len(negs) < num_negs): neg = random.choice(vocab) # 保證負(fù)樣本neg與pos1沒(méi)有邊相連接 if not (neg in network[pos1] or pos1 in network[neg] or neg == pos1): dist_neg_ = dist1(emb[pos1], emb[neg]) dist_neg = np.arccosh(dist_neg_) negs.append([pos1, neg]) dist_negs_.append(dist_neg_) # 保存中間變量gamma,加速計(jì)算 dist_negs.append(dist_neg) # 針對(duì)一個(gè)樣本的損失 loss_neg = 0.0 for dist_neg in dist_negs: loss_neg += exp(-1 * dist_neg) loss.append(dist_pos + log(loss_neg)) # 損失函數(shù) 對(duì) 正樣本對(duì) 距離 d(u, v) 的導(dǎo)數(shù) grad_L_pos = -1 # 損失函數(shù) 對(duì) 負(fù)樣本對(duì) 距離 d(u, v') 的導(dǎo)數(shù) grad_L_negs = [] for dist_neg in dist_negs: grad_L_negs.append(exp(-dist_neg) / loss_neg) # 計(jì)算正樣本對(duì)中兩個(gè)樣本的embedding的更新方向 grad_pos1 = grad_L_pos * compute_distance_gradients(emb[pos1], emb[pos2], dist_pos_) grad_pos2 = grad_L_pos * compute_distance_gradients(emb[pos2], emb[pos1], dist_pos_) # 計(jì)算負(fù)樣本對(duì)中所有樣本的embedding的更新方向 grad_negs_final = [] for (grad_L_neg, neg, dist_neg_) in zip(grad_L_negs[1:], negs[1:], dist_negs_[1:]): grad_neg0 = grad_L_neg * compute_distance_gradients(emb[neg[0]], emb[neg[1]], dist_neg_) grad_neg1 = grad_L_neg * compute_distance_gradients(emb[neg[1]], emb[neg[0]], dist_neg_) grad_negs_final.append([grad_neg0, grad_neg1]) # 更新embeddings emb[pos1] = update(emb[pos1], -grad_pos1, lr) emb[pos2] = update(emb[pos2], -grad_pos2, lr) for (neg, grad_neg) in zip(negs, grad_negs_final): emb[neg[0]] = update(emb[neg[0]], -grad_neg[0], lr) emb[neg[1]] = update(emb[neg[1]], -grad_neg[1], lr) # 輸出損失 if ((epoch) % 10 == 0): print(epoch + 1, "---Loss: ", sum(loss)) # 繪制二維embeddings if ((epoch) % 100 == 0): plotall(epoch + 1)
3. 結(jié)果表現(xiàn)
結(jié)果如下所示(與論文有些不一致):
實(shí)際上應(yīng)該還是有效的,有些團(tuán)都能聚合在一起,下面是一個(gè)隨機(jī)訓(xùn)練的結(jié)果(可以看出非?;靵y):
其他參考資料
Poincaré Embeddings for Learning Hierarchical Representations
Implementing Poincaré Embeddings
models.poincare – Train and use Poincare embeddings
How to make a graph on Python describing WordNet's synsets (NLTK)
networkx.drawing.nx_pylab.draw_networkx
以上就是python算法學(xué)習(xí)雙曲嵌入論文方法與代碼解析說(shuō)明的詳細(xì)內(nèi)容,更多關(guān)于python雙曲嵌入論文方法與代碼的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python著名游戲?qū)崙?zhàn)之方塊連接 我的世界
讀萬(wàn)卷書(shū)不如行萬(wàn)里路,學(xué)的扎不扎實(shí)要通過(guò)實(shí)戰(zhàn)才能看出來(lái),本篇文章手把手帶你模仿著名游戲——我的世界,大家可以在過(guò)程中查缺補(bǔ)漏,看看自己掌握程度怎么樣2021-10-10基于Python實(shí)現(xiàn)文本文件轉(zhuǎn)Excel
Excel文件是我們常用的一種文件,在工作中使用非常頻繁。Excel中有許多強(qiáng)大工具,因此用Excel來(lái)處理文件會(huì)給我們帶來(lái)很多便捷。本文就來(lái)和大家分享一下Python實(shí)現(xiàn)文本文件轉(zhuǎn)Excel的方法,感興趣的可以了解一下2022-08-08python輸出國(guó)際象棋棋盤(pán)的實(shí)例分享
在本篇文章里小編給大家整理的是一篇關(guān)于python輸出國(guó)際象棋棋盤(pán)的實(shí)例詳解,有興趣的朋友們可以參考下。2020-11-11利用python模擬實(shí)現(xiàn)POST請(qǐng)求提交圖片的方法
最近在利用python做接口測(cè)試,其中有個(gè)上傳圖片的接口,在網(wǎng)上各種搜索,各種嘗試。下面這篇文章主要給大家介紹了關(guān)于利用python模擬實(shí)現(xiàn)POST請(qǐng)求提交圖片的相關(guān)資料,需要的朋友可以參考借鑒,下面來(lái)一起看看吧。2017-07-07python數(shù)據(jù)擬合之scipy.optimize.curve_fit解讀
這篇文章主要介紹了python數(shù)據(jù)擬合之scipy.optimize.curve_fit解讀,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-12-12500行python代碼實(shí)現(xiàn)飛機(jī)大戰(zhàn)
這篇文章主要為大家詳細(xì)介紹了500行python代碼實(shí)現(xiàn)飛機(jī)大戰(zhàn),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2020-04-04python實(shí)現(xiàn)全盤(pán)掃描搜索功能的方法
今天小編就為大家分享一篇python實(shí)現(xiàn)全盤(pán)掃描搜索功能的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-02-02