PyTorch up/downsampling with torch.nn.functional.interpolate explained
```python
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)
'''
Down/up samples the input to either the given size or the given scale_factor.
The algorithm used for interpolation is determined by mode.
Currently temporal, spatial and volumetric sampling are supported, i.e.
expected inputs are 3-D, 4-D or 5-D in shape.
The input dimensions are interpreted in the form:
mini-batch x channels x [optional depth] x [optional height] x width.
The modes available for resizing are:
nearest, linear (3D-only), bilinear, bicubic (4D-only), trilinear (5D-only), area
'''
```
This function upsamples or downsamples the spatial dimensions (h, w) of a tensor:

input_tensor accepts 3D (b, c, w) (or (batch, seq_len, dim)), 4D (b, c, h, w), and 5D (b, c, f, h, w) tensor shapes, where b is batch_size, c is channels, f is frames, h is height, and w is width.

size is the target shape (w)/(h, w)/(f, h, w) of the output tensor; scale_factor is the multiplier applied to the spatial shape (w)/(h, w)/(f, h, w). Only one of size and scale_factor may be specified; whether the call upsamples or downsamples follows from these two parameters. If size or scale_factor is given as a sequence, its length must match the input:
- For 3D input, the sequence length must be 1 (only the last dimension w is scaled).
- For 4D input, the sequence length must be 2 (the last 2 dimensions h, w are scaled).
- For 5D input, the sequence length must be 3 (the last 3 dimensions f, h, w are scaled).
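The size/scale_factor constraint can be sketched with a minimal 4D example (the shapes here are illustrative, not from the original):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)  # 4D input: (b, c, h, w)

# size and scale_factor describe the same resize here
y_size = F.interpolate(x, size=(64, 64), mode='nearest')
y_scale = F.interpolate(x, scale_factor=2, mode='nearest')
print(y_size.shape, y_scale.shape)  # both torch.Size([1, 3, 64, 64])

# Specifying both at once is rejected
try:
    F.interpolate(x, size=(64, 64), scale_factor=2)
except ValueError as e:
    print(e)
```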
The interpolation algorithm mode can be: nearest (nearest neighbor, the default), linear (3D-only), bilinear (4D-only), trilinear (5D-only), and others.
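Two modes not exercised in the examples below are bicubic (4D-only) and area (effectively adaptive average pooling, a common choice for downsampling); a short sketch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)

# bicubic: 4D-only, interpolates with a cubic kernel (smoother than bilinear)
y_bicubic = F.interpolate(x, scale_factor=2, mode='bicubic', align_corners=False)

# area: averages over input regions, well suited to downsampling
y_area = F.interpolate(x, scale_factor=0.5, mode='area')

print(y_bicubic.shape, y_area.shape)
```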
align_corners: an optional bool. If align_corners=True, the corner pixels of input and output are aligned, preserving the values at those corner pixels. It only takes effect for mode=linear, bilinear, trilinear. The default is False. Consider upsampling from 4×4 to 8×8: with align_corners=True the grids are aligned at the centers of the four corner pixels, while with align_corners=False they are aligned at the outer corners of the four corner pixels.
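The effect on actual values is easiest to see in a tiny 1-D sketch (the ramp input below is illustrative, not from the original):

```python
import torch
import torch.nn.functional as F

# A ramp 0..3 along the last dimension, shape (1, 1, 4)
x = torch.arange(4.0).reshape(1, 1, 4)

y_true = F.interpolate(x, size=8, mode='linear', align_corners=True)
y_false = F.interpolate(x, size=8, mode='linear', align_corners=False)

print(y_true)   # endpoint values 0.0 and 3.0 are reproduced exactly
print(y_false)  # interior sample positions are shifted by half a pixel
```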

```python
import torch
import torch.nn.functional as F

b, c, f, h, w = 1, 3, 8, 64, 64
```
1. upsample/downsample 3D tensor
```python
# interpolate 3D tensor
x = torch.randn([b, c, w])

## downsample to (b, c, w/2)
y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
y1 = F.interpolate(x, size=[w//2], mode='nearest')
y2 = F.interpolate(x, scale_factor=0.5, mode='linear')  # only 3D
y3 = F.interpolate(x, size=[w//2], mode='linear')       # only 3D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32])

## upsample to (b, c, w*2)
y0 = F.interpolate(x, scale_factor=2, mode='nearest')
y1 = F.interpolate(x, size=[w*2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='linear')  # only 3D
y3 = F.interpolate(x, size=[w*2], mode='linear')      # only 3D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128])
```
2. upsample/downsample 4D tensor
```python
# interpolate 4D tensor
x = torch.randn(b, c, h, w)

## downsample to (b, c, h/2, w/2)
y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
y1 = F.interpolate(x, size=[h//2, w//2], mode='nearest')
y2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')  # only 4D
y3 = F.interpolate(x, size=[h//2, w//2], mode='bilinear') # only 4D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32])

## upsample to (b, c, h*2, w*2)
y0 = F.interpolate(x, scale_factor=2, mode='nearest')
y1 = F.interpolate(x, size=[h*2, w*2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='bilinear')  # only 4D
y3 = F.interpolate(x, size=[h*2, w*2], mode='bilinear') # only 4D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128])
```
3. upsample/downsample 5D tensor
```python
# interpolate 5D tensor
x = torch.randn(b, c, f, h, w)

## downsample to (b, c, f/2, h/2, w/2)
y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
y1 = F.interpolate(x, size=[f//2, h//2, w//2], mode='nearest')
y2 = F.interpolate(x, scale_factor=0.5, mode='trilinear')       # only 5D
y3 = F.interpolate(x, size=[f//2, h//2, w//2], mode='trilinear') # only 5D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 4, 32, 32])

## upsample to (b, c, f*2, h*2, w*2)
y0 = F.interpolate(x, scale_factor=2, mode='nearest')
y1 = F.interpolate(x, size=[f*2, h*2, w*2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='trilinear')       # only 5D
y3 = F.interpolate(x, size=[f*2, h*2, w*2], mode='trilinear') # only 5D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128])
```
Summary

The above is based on personal experience; I hope it serves as a useful reference.