Pytorch?PyG實(shí)現(xiàn)EdgePool圖分類
EdgePool簡(jiǎn)介
EdgePool是一種用于圖分類的卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network,CNN)模型。其主要思想是通過(guò) edge pooling 上下采樣優(yōu)化圖像大小,減少空間復(fù)雜度,提高分類性能。
實(shí)現(xiàn)步驟
數(shù)據(jù)準(zhǔn)備
一般來(lái)講,在構(gòu)建較大規(guī)模數(shù)據(jù)集時(shí),我們都需要對(duì)數(shù)據(jù)進(jìn)行規(guī)范、歸一和清洗處理,以便后續(xù)語(yǔ)義分析或深度學(xué)習(xí)操作。而在圖像數(shù)據(jù)集中,則需使用特定的框架或工具庫(kù)完成。
# 導(dǎo)入MNIST數(shù)據(jù)集 from torch_geometric.datasets import MNISTSuperpixels # 加載數(shù)據(jù)、劃分訓(xùn)練集和測(cè)試集 dataset = MNISTSuperpixels(root='./mnist', transform=Compose([ToTensor(), NormalizeMeanStd()])) data = dataset[0] # 定義超級(jí)參數(shù) num_features = dataset.num_features num_classes = dataset.num_classes # 構(gòu)建訓(xùn)練集和測(cè)試集索引文件 train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8) train_mask[:60000] = 1 test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8) test_mask[60000:] = 1 # 創(chuàng)建數(shù)據(jù)加載器 train_loader = DataLoader(data[train_mask], batch_size=32, shuffle=True) test_loader = DataLoader(data[test_mask], batch_size=32, shuffle=False)
實(shí)現(xiàn)模型
在定義EdgePool模型時(shí),我們需要重新考慮網(wǎng)絡(luò)結(jié)構(gòu)中的上下采樣操作,以便讓整個(gè)網(wǎng)絡(luò)擁有更強(qiáng)大的表達(dá)能力,從而學(xué)習(xí)到更復(fù)雜的關(guān)系。
from torch.nn import Linear from torch_geometric.nn import EdgePooling class EdgePool(torch.nn.Module): def __init__(self, dataset): super(EdgePool, self).__init__() # 定義輸入與輸出維度數(shù) self.input_dim = dataset.num_features self.hidden_dim = 128 self.output_dim = 10 # 定義卷積層、歸一化層和pooling層等 self.conv1 = GCNConv(self.input_dim, self.hidden_dim) self.norm1 = BatchNorm1d(self.hidden_dim) self.pool1 = EdgePooling(self.hidden_dim) self.conv2 = GCNConv(self.hidden_dim, self.hidden_dim) self.norm2 = BatchNorm1d(self.hidden_dim) self.pool2 = EdgePooling(self.hidden_dim) self.conv3 = GCNConv(self.hidden_dim, self.hidden_dim) self.norm3 = BatchNorm1d(self.hidden_dim) self.pool3 = EdgePooling(self.hidden_dim) self.lin = torch.nn.Linear(self.hidden_dim, self.output_dim) def forward(self, x, edge_index, batch): x = F.relu(self.norm1(self.conv1(x, edge_index))) x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch) x = F.relu(self.norm2(self.conv2(x, edge_index))) x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch) x = F.relu(self.norm3(self.conv3(x, edge_index))) x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch) x = global_mean_pool(x, batch) x = self.lin(x) return x
在上述代碼中,我們使用了不同的卷積層、池化層和全連接層等神經(jīng)網(wǎng)絡(luò)功能塊來(lái)構(gòu)建EdgePool模型。其中,每個(gè) GCNConv 層被保持為128的隱藏尺寸;BatchNorm1d是一種旨在提高收斂速度并增強(qiáng)網(wǎng)絡(luò)泛化能力的方法;EdgePooling是一種在 GraphConvolution 上附加的特殊類別,它將給定圖下采樣至其一半的大小,并返回縮小后的圖與兩個(gè)跟蹤full-graph-to-pool雙向映射(keep and senders)的 edge index(edgendarcs)。 在這種情況下傳遞 None ,表明 batch
未更改。
模型訓(xùn)練
在定義好 EdgePool 網(wǎng)絡(luò)結(jié)構(gòu)之后,需要指定合適的優(yōu)化器、損失函數(shù),并控制訓(xùn)練輪數(shù)、批量大小與學(xué)習(xí)率等超參數(shù)。同時(shí)還要記錄大量日志信息,方便后期跟蹤和駕駛員。
# 定義訓(xùn)練計(jì)劃,包括損失函數(shù)、優(yōu)化器及迭代次數(shù)等 train_epochs = 50 learning_rate = 0.01 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(edge_pool.parameters(), lr=learning_rate) losses_per_epoch = [] accuracies_per_epoch = [] for epoch in range(train_epochs): running_loss = 0.0 running_corrects = 0.0 count = 0.0 for samples in train_loader: optimizer.zero_grad() x, edge_index, batch = samples.x, samples.edge_index, samples.batch out = edge_pool(x, edge_index, batch) label = samples.y loss = criterion(out, label) loss.backward() optimizer.step() running_loss += loss.item() / len(train_loader.dataset) pred = out.argmax(dim=1) running_corrects += pred.eq(label).sum().item() / len(train_loader.dataset) count += 1 losses_per_epoch.append(running_loss) accuracies_per_epoch.append(running_corrects) if (epoch + 1) % 10 == 0: print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format( epoch + 1, train_epochs, running_loss, running_corrects))
在訓(xùn)練過(guò)程中,我們遍歷了每個(gè)批次的數(shù)據(jù),并通過(guò)反向傳播算法進(jìn)行優(yōu)化,并更新了 loss 和 accuracy 輸出值。 同時(shí)方便可視化與記錄,需要將訓(xùn)練過(guò)程中的 loss 和 accuracy 輸出到相應(yīng)的容器中,以便后期進(jìn)行分析和處理。
以上就是Pytorch PyG實(shí)現(xiàn)EdgePool圖分類的詳細(xì)內(nèi)容,更多關(guān)于Pytorch PyG EdgePool圖分類的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python?NumPy教程之?dāng)?shù)組的基本操作詳解
Numpy?中的數(shù)組是一個(gè)元素表(通常是數(shù)字),所有元素類型相同,由正整數(shù)元組索引。本文將通過(guò)一些示例詳細(xì)講一下NumPy中數(shù)組的一些基本操作,需要的可以參考一下2022-08-08Python調(diào)用訊飛語(yǔ)音合成API接口來(lái)實(shí)現(xiàn)文字轉(zhuǎn)語(yǔ)音
這篇文章主要為大家介紹了Python調(diào)用訊飛語(yǔ)音合成API接口來(lái)實(shí)現(xiàn)文字轉(zhuǎn)語(yǔ)音方法示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-04-04K近鄰法(KNN)相關(guān)知識(shí)總結(jié)以及如何用python實(shí)現(xiàn)
這篇文章主要介紹了K近鄰法(KNN)相關(guān)知識(shí)總結(jié)以及如何用python實(shí)現(xiàn),幫助大家更好的利用python實(shí)現(xiàn)機(jī)器學(xué)習(xí),感興趣的朋友可以了解下2021-01-01python 圖像增強(qiáng)算法實(shí)現(xiàn)詳解
這篇文章主要介紹了python 圖像增強(qiáng)算法實(shí)現(xiàn)詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-01-01python?魔法方法之?__?slots?__的實(shí)現(xiàn)
本文主要介紹了python?魔法方法之?__?slots?__的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03Pandas中字符串和時(shí)間轉(zhuǎn)換與格式化的實(shí)現(xiàn)
本文主要介紹了Pandas中字符串和時(shí)間轉(zhuǎn)換與格式化的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-01-01Python獲取當(dāng)前路徑實(shí)現(xiàn)代碼
這篇文章主要介紹了 Python獲取當(dāng)前路徑實(shí)現(xiàn)代碼的相關(guān)資料,需要的朋友可以參考下2017-05-05詳解python中TCP協(xié)議中的粘包問(wèn)題
這篇文章主要介紹了python中TCP協(xié)議中的粘包問(wèn)題,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-03-03python數(shù)據(jù)分析必會(huì)的Pandas技巧匯總
用Python做數(shù)據(jù)分析光是掌握numpy和matplotlib可不夠,numpy雖然能夠幫我們處理處理數(shù)值型數(shù)據(jù),但很多時(shí)候,還有字符串,還有時(shí)間序列等,比如:我們通過(guò)爬蟲(chóng)獲取到了存儲(chǔ)在數(shù)據(jù)庫(kù)中的數(shù)據(jù),一些Pandas必會(huì)的用法,讓你的數(shù)據(jù)分析水平更上一層樓2021-08-08