YOLOv8模型pytorch格式轉(zhuǎn)為onnx格式的步驟詳解
一、YOLOv8的Pytorch網(wǎng)絡(luò)結(jié)構(gòu)
model DetectionModel( (model): Sequential( (0): Conv( (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (1): Conv( (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (2): C2f( (cv1): Conv( (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-2): 3 x Bottleneck( (cv1): Conv( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (3): Conv( (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (4): C2f( (cv1): Conv( (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-5): 6 x Bottleneck( (cv1): Conv( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (5): Conv( (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (6): C2f( (cv1): Conv( (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-5): 6 x Bottleneck( (cv1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (7): Conv( (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (8): C2f( (cv1): Conv( (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(1280, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-2): 3 x Bottleneck( (cv1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (9): SPPF( (cv1): Conv( (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False) ) (10): Upsample(scale_factor=2.0, mode='nearest') (11): Concat() (12): C2f( (cv1): Conv( (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(1280, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-2): 3 x Bottleneck( (cv1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (13): Upsample(scale_factor=2.0, mode='nearest') (14): Concat() (15): C2f( (cv1): Conv( (conv): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(640, 256, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-2): 3 x Bottleneck( (cv1): Conv( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (16): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (17): Concat() (18): C2f( (cv1): Conv( (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(1280, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-2): 3 x Bottleneck( (cv1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (19): Conv( (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (act): SiLU(inplace=True) ) (20): Concat() (21): C2f( (cv1): Conv( (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(1280, 512, kernel_size=(1, 1), stride=(1, 1)) (act): SiLU(inplace=True) ) (m): ModuleList( (0-2): 3 x Bottleneck( (cv1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (cv2): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) ) ) ) (22): PostDetect( (cv2): ModuleList( (0): Sequential( (0): Conv( (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (1): Conv( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1-2): 2 x Sequential( (0): Conv( (conv): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (1): Conv( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (cv3): ModuleList( (0): Sequential( (0): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (2): Conv2d(256, 35, kernel_size=(1, 1), stride=(1, 1)) ) (1-2): 2 x Sequential( (0): Conv( (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (1): Conv( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (act): SiLU(inplace=True) ) (2): Conv2d(256, 35, kernel_size=(1, 1), stride=(1, 1)) ) ) (dfl): DFL( (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) ) ) ) )
yolov8網(wǎng)絡(luò)從1-21層與pt文件相對(duì)應(yīng)是BackBone和Neck模塊,22層是Head模塊。
二、轉(zhuǎn)ONNX步驟
2.1 yolov8官方
""" 代碼解釋 pt模型轉(zhuǎn)為onnx格式 """ import os from ultralytics import YOLO model = YOLO("weights/best.pt") success = model.export(format="onnx") print("導(dǎo)出成功!")
將pytorch轉(zhuǎn)為onnx后,pytorch支持的一系列計(jì)算就會(huì)轉(zhuǎn)為onnx所支持的算子,若沒有相對(duì)應(yīng)的就會(huì)使用其他方式進(jìn)行替換(比如多個(gè)計(jì)算替換其單個(gè))。比較常見是conv和SiLU合并成一個(gè)Conv模塊進(jìn)行。
其中,1*4*8400表示每張圖片預(yù)測(cè) 8400 個(gè)候選框,每個(gè)框有 4 個(gè)參數(shù)邊界框坐標(biāo) (x,y,w,h)。 1*35*8400類同,1和4800代表意義相同,35是類別屬性包含了其置信度概率值。
最后兩個(gè)輸出Concat操作,得到1*39*8400。最后根據(jù)這個(gè)結(jié)果去進(jìn)行后續(xù)操作。
2.2 自定義轉(zhuǎn)換
所謂的自定義轉(zhuǎn)換其實(shí)是在轉(zhuǎn)onnx時(shí),對(duì)1*39*8400多加了一系列自定義操作例如NMS等。
2.2.1 加載權(quán)重并優(yōu)化結(jié)構(gòu)
YOLOv8 = YOLO(args.weights) #替換為自己的權(quán)重 model = YOLOv8.model.fuse().eval()
2.2.2 后處理檢測(cè)模塊
def gen_anchors(feats: Tensor, strides: Tensor, grid_cell_offset: float = 0.5) -> Tuple[Tensor, Tensor]: """ 生成錨點(diǎn),并計(jì)算每個(gè)錨點(diǎn)的步幅。 參數(shù): feats (Tensor): 特征圖,通常來自不同的網(wǎng)絡(luò)層。 strides (Tensor): 每個(gè)特征圖的步幅(stride)。 grid_cell_offset (float): 網(wǎng)格單元的偏移量,默認(rèn)為0.5。 返回: Tuple[Tensor, Tensor]: 錨點(diǎn)的坐標(biāo)和對(duì)應(yīng)的步幅張量。 """ anchor_points, stride_tensor = [], [] assert feats is not None # 確保輸入的特征圖不為空 dtype, device = feats[0].dtype, feats[0].device # 獲取特征圖的數(shù)據(jù)類型和設(shè)備 # 遍歷每個(gè)特征圖,計(jì)算錨點(diǎn) for i, stride in enumerate(strides): _, _, h, w = feats[i].shape # 獲取特征圖的高(h)和寬(w) sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # 計(jì)算 x 軸上的錨點(diǎn)位置 sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # 計(jì)算 y 軸上的錨點(diǎn)位置 sy, sx = torch.meshgrid(sy, sx) # 生成網(wǎng)格坐標(biāo) anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) # 將 x 和 y 組合成坐標(biāo)點(diǎn) stride_tensor.append( torch.full((h * w, 1), stride, dtype=dtype, device=device)) # 生成步幅張量 return torch.cat(anchor_points), torch.cat(stride_tensor) # 返回合并后的錨點(diǎn)和步幅 class customize_NMS(torch.autograd.Function): """ 繼承torch.autograd.Function 用于TensorRT的非極大值抑制(NMS)自定義函數(shù)。 """ @staticmethod def forward( ctx: Graph, boxes: Tensor, scores: Tensor, iou_threshold: float = 0.65, score_threshold: float = 0.25, max_output_boxes: int = 100, background_class: int = -1, box_coding: int = 0, plugin_version: str = '1', score_activation: int = 0 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ 正向計(jì)算NMS輸出,模擬真實(shí)的TensorRT NMS過程。 參數(shù): boxes (Tensor): 預(yù)測(cè)的邊界框。 scores (Tensor): 預(yù)測(cè)框的置信度分?jǐn)?shù)。 其他參數(shù)同樣為NMS的超參數(shù)。 返回: Tuple[Tensor, Tensor, Tensor, Tensor]: 包含檢測(cè)框數(shù)量、框坐標(biāo)、置信度分?jǐn)?shù)和類別標(biāo)簽。 """ batch_size, num_boxes, num_classes = scores.shape # 獲取批量大小、框數(shù)量和類別數(shù) num_dets = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) # 隨機(jī)生成檢測(cè)框數(shù)量(僅為模擬) boxes = torch.randn(batch_size, max_output_boxes, 4) # 隨機(jī)生成預(yù)測(cè)框 scores = torch.randn(batch_size, max_output_boxes) # 隨機(jī)生成分?jǐn)?shù) labels = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) # 隨機(jī)生成類別標(biāo)簽 return num_dets, boxes, scores, labels # 返回模擬的結(jié)果 @staticmethod def symbolic( g, boxes: Value, scores: Value, iou_threshold: float = 0.45, score_threshold: float = 0.25, max_output_boxes: int = 100, background_class: int = -1, box_coding: int = 0, score_activation: int = 0, plugin_version: str = '1') -> Tuple[Value, Value, Value, Value]: """ 計(jì)算圖的符號(hào)函數(shù),供TensorRT使用。 參數(shù): g: 計(jì)算圖對(duì)象 boxes (Value), scores (Value): 傳入的邊界框和得分 其他參數(shù)是用于配置NMS的參數(shù)。 返回: 經(jīng)過NMS處理的檢測(cè)框、得分、類別標(biāo)簽及檢測(cè)框數(shù)量。 """ out = g.op('TRT::EfficientNMS_TRT', boxes, scores, iou_threshold_f=iou_threshold, score_threshold_f=score_threshold, max_output_boxes_i=max_output_boxes, background_class_i=background_class, box_coding_i=box_coding, plugin_version_s=plugin_version, score_activation_i=score_activation, outputs=4) # 使用TensorRT的EfficientNMS插件 nums_dets, boxes, scores, classes = out # 獲取輸出的檢測(cè)框數(shù)量、框坐標(biāo)、得分和類別 return nums_dets, boxes, scores, classes # 返回結(jié)果 class Post_process_Detect(nn.Module): """ 用于后處理的檢測(cè)模塊,執(zhí)行檢測(cè)后的非極大值抑制(NMS)。 """ export = True shape = None dynamic = False iou_thres = 0.65 # 默認(rèn)的IoU閾值 conf_thres = 0.25 # 默認(rèn)的置信度閾值 topk = 100 # 輸出的最大檢測(cè)框數(shù)量 def __init__(self, *args, **kwargs): super().__init__() def forward(self, x): """ 執(zhí)行后處理操作,提取預(yù)測(cè)框、置信度和類別。 參數(shù): x (Tensor): 輸入的特征圖。 返回: Tuple[Tensor, Tensor, Tensor]: 預(yù)測(cè)框、置信度和類別。 """ shape = x[0].shape # 獲取輸入的形狀 b, res, b_reg_num = shape[0], [], self.reg_max * 4 # b為特征列表第一個(gè)元素的批量大小,表示處理的樣本數(shù)量, # res聲明一個(gè)空列表存儲(chǔ)處理過的特征圖 # b_reg_num為回歸框的數(shù)量 #遍歷特征層(self.nl表示特征層數(shù)),將每一層的框預(yù)測(cè)和分類預(yù)測(cè)拼接。 for i in range(self.nl): res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) # 特征拼接 # 調(diào)用 # make_anchors # 生成錨點(diǎn)和步幅,用于還原邊界框的絕對(duì)坐標(biāo)。 if self.dynamic or self.shape != shape: self.anchors, self.strides = (x.transpose( 0, 1) for x in gen_anchors(x, self.stride, 0.5)) # 生成錨點(diǎn)和步幅 self.shape = shape # 更新輸入的形狀 x = [i.view(b, self.no, -1) for i in res] # 調(diào)整特征圖形狀 y = torch.cat(x, 2) # 拼接所有特征圖 boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid() # 提取框和分?jǐn)?shù) boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2) # 變換框的形狀 boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes) # 對(duì)框進(jìn)行softmax處理 boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...] # 分離框的不同部分 boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1) # 合并框坐標(biāo) boxes = boxes * self.strides # 乘以步幅 return customize_NMS.apply(boxes.transpose(1, 2), scores.transpose(1, 2), self.iou_thres, self.conf_thres, self.topk) # 執(zhí)行NMS def optim(module: nn.Module): setattr(module, '__class__', Post_process_Detect) for item in model.modules(): optim(item) item.to(args.device) #輸入cpu或者gpu的卡號(hào)
自定義這里是在yolo官方得到的1*4*8400和1*35*8400進(jìn)行矩陣轉(zhuǎn)換2<->3,最后引入EfficientNMS_TRT插件后處理,可以有效加速NMS處理。
2.2.3 EfficientNMS_TRT插件
EfficientNMS_TRT
是 TensorRT 中的一個(gè)高效非極大值抑制 (NMS) 插件,用于快速過濾檢測(cè)框。它通過優(yōu)化的 CUDA 實(shí)現(xiàn)來執(zhí)行 NMS 操作,特別適合于深度學(xué)習(xí)推理階段中目標(biāo)檢測(cè)任務(wù)的后處理。支持在一個(gè)批次中對(duì)多個(gè)圖像同時(shí)執(zhí)行 NMS。
輸出結(jié)果為num_dets
, detection_boxes, detection_scores, detection_classes
,分別代表經(jīng)過 NMS 篩選后保留的邊界框數(shù),每張圖片保留的檢測(cè)框的坐標(biāo),每張圖片中保留下來的檢測(cè)框的分?jǐn)?shù)(由高到低),每個(gè)保留下來的邊界框的類別索引。
三、結(jié)語
到此這篇關(guān)于YOLOv8模型pytorch格式轉(zhuǎn)為onnx格式的文章就介紹到這了,更多相關(guān)YOLOv8模型pytorch轉(zhuǎn)onnx格式內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python 中的判斷語句,循環(huán)語句,函數(shù)
這篇文章主要介紹了Python 中的判斷語句,循環(huán)語句,函數(shù),文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-08-08Python使用微信itchat接口實(shí)現(xiàn)查看自己微信的信息功能詳解
這篇文章主要介紹了Python使用微信itchat接口實(shí)現(xiàn)查看自己微信的信息功能,結(jié)合實(shí)例形式分析了Python微信itchat模塊常見功能與操作技巧,需要的朋友可以參考下2019-08-08Python腳本實(shí)現(xiàn)音頻和視頻格式轉(zhuǎn)換
這篇文章主要為大家詳細(xì)介紹了Python如何通過腳本實(shí)現(xiàn)音頻和視頻格式轉(zhuǎn)換,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2025-03-03django 2.0更新的10條注意事項(xiàng)總結(jié)
Django 是 Python Web 開發(fā)最常用的框架之一,跟進(jìn)它的最新變化絕對(duì)是必須的。下面這篇文章主要給大家介紹了關(guān)于django 2.0更新的10條注意事項(xiàng),文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考借鑒,下面來一起看看吧。2018-01-01對(duì)DataFrame數(shù)據(jù)中的重復(fù)行,利用groupby累加合并的方法詳解
今天小編就為大家分享一篇對(duì)DataFrame數(shù)據(jù)中的重復(fù)行,利用groupby累加合并的方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01Python網(wǎng)絡(luò)請(qǐng)求模塊urllib與requests使用介紹
網(wǎng)絡(luò)爬蟲的第一步就是根據(jù)URL,獲取網(wǎng)頁的HTML信息。在Python3中,可以使用urllib和requests進(jìn)行網(wǎng)頁數(shù)據(jù)獲取,這篇文章主要介紹了Python網(wǎng)絡(luò)請(qǐng)求模塊urllib與requests使用2022-10-10