Pytorch distributed 多卡并行載入模型操作
一、Pytorch distributed 多卡并行載入模型
這次來介紹下如何載入模型。
目前沒有找到官方的distribute 載入模型的方式,所以采用如下方式。
大部分情況下,我們在測試時(shí)不需要多卡并行計(jì)算。
所以,我在測試時(shí)只使用單卡。
from collections import OrderedDict device = torch.device("cuda") model = DGCNN(args).to(device) #自己的模型 state_dict = torch.load(args.model_path) #存放模型的位置 new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[7:] # remove `module.` new_state_dict[name] = v # load params model.load_state_dict (new_state_dict)
二、pytorch DistributedParallel進(jìn)行單機(jī)多卡訓(xùn)練
One_導(dǎo)入庫:
import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler
Two_進(jìn)程初始化:
parser = argparse.ArgumentParser() parser.add_argument('--local_rank', type=int, default=-1) # 添加必要參數(shù) # local_rank:系統(tǒng)自動(dòng)賦予的進(jìn)程編號(hào),可以利用該編號(hào)控制打印輸出以及設(shè)置device torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile', rank=local_rank, world_size=world_size) # world_size:所創(chuàng)建的進(jìn)程數(shù),也就是所使用的GPU數(shù)量 # (初始化設(shè)置詳見參考文檔)
Three_數(shù)據(jù)分發(fā):
dataset = datasets.ImageFolder(dataPath) data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size) # 使用DistributedSampler來為各個(gè)進(jìn)程分發(fā)數(shù)據(jù),其中num_replicas與world_size保持一致,用于將數(shù)據(jù)集等分成不重疊的數(shù)個(gè)子集 dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler) # 在Dataloader中指定sampler時(shí),其中的shuffle必須為False,而DistributedSampler中的shuffle項(xiàng)默認(rèn)為True,因此訓(xùn)練過程默認(rèn)執(zhí)行shuffle
Four_網(wǎng)絡(luò)模型:
torch.cuda.set_device(local_rank) device = torch.device('cuda:'+f'{local_rank}') # 設(shè)置每個(gè)進(jìn)程對(duì)應(yīng)的GPU設(shè)備 D = Model() D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device) # 由于在訓(xùn)練過程中各卡的前向后向傳播均獨(dú)立進(jìn)行,因此無法進(jìn)行統(tǒng)一的批歸一化,如果想要將各卡的輸出統(tǒng)一進(jìn)行批歸一化,需要將模型中的BN轉(zhuǎn)換成SyncBN D = torch.nn.parallel.DistributedDataParallel( D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank) # 如果有forward的返回值如果不在計(jì)算loss的計(jì)算圖里,那么需要find_unused_parameters=True,即返回值不進(jìn)入backward去算grad,也不需要在不同進(jìn)程之間進(jìn)行通信。
Five_迭代:
data_sampler.set_epoch(epoch) # 每個(gè)epoch需要為sampler設(shè)置當(dāng)前epoch
Six_加載:
dist.barrier() D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu')) dist.barrier() # 加載模型前后用dist.barrier()來同步不同進(jìn)程間的快慢
Seven_啟動(dòng):
CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2 # 用-m torch.distributed.launch啟動(dòng),nproc_per_node為所使用的卡數(shù),batchsize設(shè)置為每張卡各自的批大小
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
jupyter運(yùn)行時(shí)左邊一直出現(xiàn)*號(hào)問題及解決
這篇文章主要介紹了jupyter運(yùn)行時(shí)左邊一直出現(xiàn)*號(hào)問題及解決方案,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-09-09django+celery+RabbitMQ自定義多個(gè)消息隊(duì)列的實(shí)現(xiàn)
本文主要介紹了django+celery+RabbitMQ自定義多個(gè)消息隊(duì)列的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02安裝ElasticSearch搜索工具并配置Python驅(qū)動(dòng)的方法
這篇文章主要介紹了安裝ElasticSearch搜索工具并配置Python驅(qū)動(dòng)的方法,文中還介紹了其與Kibana數(shù)據(jù)顯示客戶端的配合使用,需要的朋友可以參考下2015-12-12opencv+python識(shí)別七段數(shù)碼顯示器的數(shù)字(數(shù)字識(shí)別)
本文主要介紹了opencv+python識(shí)別七段數(shù)碼顯示器的數(shù)字(數(shù)字識(shí)別),文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-01-01基于Python實(shí)現(xiàn)語音識(shí)別和語音轉(zhuǎn)文字
這篇文章主要為大家詳細(xì)介紹了如何利用Python實(shí)現(xiàn)語音識(shí)別和語音轉(zhuǎn)文字功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2022-09-09