PyTorch中的C++擴(kuò)展實(shí)現(xiàn)
今天要聊聊用 PyTorch 進(jìn)行 C++ 擴(kuò)展。
在正式開始前,我們需要了解 PyTorch 如何自定義module。這其中,最常見的就是在 python 中繼承torch.nn.Module,用 PyTorch 中已有的 operator 來組裝成自己的模塊。這種方式實(shí)現(xiàn)簡(jiǎn)單,但是,計(jì)算效率卻未必最佳,另外,如果我們想實(shí)現(xiàn)的功能過于復(fù)雜,可能 PyTorch 中那些已有的函數(shù)也沒法滿足我們的要求。這時(shí),用 C、C++、CUDA 來擴(kuò)展 PyTorch 的模塊就是最佳的選擇了。
由于目前市面上大部分深度學(xué)習(xí)系統(tǒng)(TensorFlow、PyTorch 等)都是基于 C、C++ 構(gòu)建的后端,因此這些系統(tǒng)基本都存在 C、C++ 的擴(kuò)展接口。PyTorch 是基于 Torch 構(gòu)建的,而 Torch 底層采用的是 C 語言,因此 PyTorch 天生就和 C 兼容,因此用 C 來擴(kuò)展 PyTorch 并非難事。而隨著 PyTorch1.0 的發(fā)布,官方已經(jīng)開始考慮將 PyTorch 的底層代碼用 caffe2 替換,因此他們也在逐步重構(gòu) ATen,后者是目前 PyTorch 使用的 C++ 擴(kuò)展庫(kù)??偟膩碚f,C++ 是未來的趨勢(shì)。至于 CUDA,這是幾乎所有深度學(xué)習(xí)系統(tǒng)在構(gòu)建之初就采用的工具,因此 CUDA 的擴(kuò)展接口是標(biāo)配。
本文用一個(gè)簡(jiǎn)單的例子,梳理一下進(jìn)行 C++ 擴(kuò)展的步驟,至于一些具體的實(shí)現(xiàn),不做深入探討。
PyTorch的C、C++、CUDA擴(kuò)展
關(guān)于 PyTorch 的 C 擴(kuò)展,可以參考官方教程或者這篇博文,其操作并不難,無非是借助原先 Torch 提供的<TH/TH.h>
和<THC/THC.h>
等接口,再利用 PyTorch 中提供的torch.util.ffi
模塊進(jìn)行擴(kuò)展。需要注意的是,隨著 PyTorch 版本升級(jí),這種做法在新版本的 PyTorch 中可能會(huì)失效。
本文主要介紹 C++(未來可能加上 CUDA)的擴(kuò)展方法。
C++擴(kuò)展
首先,介紹一下基本流程。在 PyTorch 中擴(kuò)展 C++/CUDA 主要分為幾步:
- 安裝好 pybind11 模塊(通過 pip 或者 conda 等安裝),這個(gè)模塊會(huì)負(fù)責(zé) python 和 C++ 之間的綁定;
- 用 C++ 寫好自定義層的功能,包括前向傳播forward和反向傳播backward;
- 寫好 setup.py,并用 python 提供的setuptools來編譯并加載 C++ 代碼。
- 編譯安裝,在 python 中調(diào)用 C++ 擴(kuò)展接口。
接下來,我們就用一個(gè)簡(jiǎn)單的例子(z=2x+y)來演示這幾個(gè)步驟。
第一步
安裝 pybind11 比較簡(jiǎn)單,直接略過。我們先寫好 C++ 相關(guān)的文件:
頭文件 test.h
#include <torch/extension.h> #include <vector> // 前向傳播 torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB); // 反向傳播 std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);
注意,這里引用的<torch/extension.h>頭文件至關(guān)重要,它主要包括三個(gè)重要模塊:
- pybind11,用于 C++ 和 python 交互;
- ATen,包含 Tensor 等重要的函數(shù)和類;
- 一些輔助的頭文件,用于實(shí)現(xiàn) ATen 和 pybind11 之間的交互。
源文件 test.cpp 如下:
#include "test.h" // 前向傳播,兩個(gè) Tensor 相加。這里只關(guān)注 C++ 擴(kuò)展的流程,具體實(shí)現(xiàn)不深入探討。 torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y) { AT_ASSERTM(x.sizes() == y.sizes(), "x must be the same size as y"); torch::Tensor z = torch::zeros(x.sizes()); z = 2 * x + y; return z; } // 反向傳播 // 在這個(gè)例子中,z對(duì)x的導(dǎo)數(shù)是2,z對(duì)y的導(dǎo)數(shù)是1。 // 至于這個(gè)backward函數(shù)的接口(參數(shù),返回值)為何要這樣設(shè)計(jì),后面會(huì)講。 std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput) { torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes()); torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes()); return {gradOutputX, gradOutputY}; } // pybind11 綁定 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &Test_forward_cpu, "TEST forward"); m.def("backward", &Test_backward_cpu, "TEST backward"); }
第二步
新建一個(gè)編譯安裝的配置文件 setup.py,文件目錄安排如下:
└── csrc ├── cpu │ ├── test.cpp │ └── test.h └── setup.py
以下是 setup.py 中的內(nèi)容:
from setuptools import setup import os import glob from torch.utils.cpp_extension import BuildExtension, CppExtension # 頭文件目錄 include_dirs = os.path.dirname(os.path.abspath(__file__)) # 源代碼目錄 source_cpu = glob.glob(os.path.join(include_dirs, 'cpu', '*.cpp')) setup( name='test_cpp', # 模塊名稱,需要在python中調(diào)用 version="0.1", ext_modules=[ CppExtension('test_cpp', sources=source_cpu, include_dirs=[include_dirs]), ], cmdclass={ 'build_ext': BuildExtension } )
注意,這個(gè) C++ 擴(kuò)展被命名為test_cpp,意思是說,在 python 中可以通過test_cpp模塊來調(diào)用 C++ 函數(shù)。
第三步
在 cpu 這個(gè)目錄下,執(zhí)行下面的命令編譯安裝 C++ 代碼:
python setup.py install
之后,可以看到一堆輸出,該 C++ 模塊會(huì)被安裝在 python 的 site-packages 中。
完成上面幾步后,就可以在 python 中調(diào)用 C++ 代碼了。在 PyTorch 中,按照慣例需要先把 C++ 中的前向傳播和反向傳播封裝成一個(gè)函數(shù)op(以下代碼放在 test.py 文件中):
from torch.autograd import Function import test_cpp class TestFunction(Function): @staticmethod def forward(ctx, x, y): return test_cpp.forward(x, y) @staticmethod def backward(ctx, gradOutput): gradX, gradY = test_cpp.backward(gradOutput) return gradX, gradY
這樣一來,我們相當(dāng)于把 C++ 擴(kuò)展的函數(shù)嵌入到 PyTorch 自己的框架內(nèi)。
我查看了這個(gè)Function類的代碼,發(fā)現(xiàn)是個(gè)挺有意思的東西:
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): ... @staticmethod def forward(ctx, *args, **kwargs): r"""Performs the operation. This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types). The context can be used to store tensors that can be then retrieved during the backward pass. """ raise NotImplementedError @staticmethod def backward(ctx, *grad_outputs): r"""Defines a formula for differentiating the operation. This function is to be overridden by all subclasses. It must accept a context :attr:`ctx` as the first argument, followed by as many outputs did :func:`forward` return, and it should return as many tensors, as there were inputs to :func:`forward`. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. The context can be used to retrieve tensors saved during the forward pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple of booleans representing whether each input needs gradient. E.g., :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the first input to :func:`forward` needs gradient computated w.r.t. the output. """ raise NotImplementedError
這里需要注意一下backward的實(shí)現(xiàn)規(guī)則。該接口包含兩個(gè)參數(shù):ctx是一個(gè)輔助的環(huán)境變量,grad_outputs則是來自前一層網(wǎng)絡(luò)的梯度列表,而且這個(gè)梯度列表的數(shù)量與forward函數(shù)返回的參數(shù)數(shù)量相同,這也符合鏈?zhǔn)椒▌t的原理,因?yàn)殒準(zhǔn)椒▌t就需要把前一層中所有相關(guān)的梯度與當(dāng)前層進(jìn)行相乘或相加。同時(shí),backward需要返回forward中每個(gè)輸入?yún)?shù)的梯度,如果forward中包括 n 個(gè)參數(shù),就需要一一返回 n 個(gè)梯度。所以,在上面這個(gè)例子中,我們的backward函數(shù)接收一個(gè)參數(shù)作為輸入(forward只輸出一個(gè)變量),并返回兩個(gè)梯度(forward接收上一層兩個(gè)輸入變量)。
定義完Function后,就可以在Module中使用這個(gè)自定義op了:
import torch class Test(torch.nn.Module): def __init__(self): super(Test, self).__init__() def forward(self, inputA, inputB): return TestFunction.apply(inputA, inputB)
現(xiàn)在,我們的文件目錄變成:
├── csrc │ ├── cpu │ │ ├── test.cpp │ │ └── test.h │ └── setup.py └── test.py
之后,我們就可以將 test.py 當(dāng)作一般的 PyTorch 模塊進(jìn)行調(diào)用了。
測(cè)試
下面,我們測(cè)試一下前向傳播和反向傳播:
import torch from torch.autograd import Variable from test import Test x = Variable(torch.Tensor([1,2,3]), requires_grad=True) y = Variable(torch.Tensor([4,5,6]), requires_grad=True) test = Test() z = test(x, y) z.sum().backward() print('x: ', x) print('y: ', y) print('z: ', z) print('x.grad: ', x.grad) print('y.grad: ', y.grad)
輸出如下:
x: tensor([1., 2., 3.], requires_grad=True)
y: tensor([4., 5., 6.], requires_grad=True)
z: tensor([ 6., 9., 12.], grad_fn=<TestFunctionBackward>)
x.grad: tensor([2., 2., 2.])
y.grad: tensor([1., 1., 1.])
可以看出,前向傳播滿足 z=2x+y,而反向傳播的結(jié)果也在意料之中。
CUDA擴(kuò)展
雖然 C++ 寫的代碼可以直接跑在 GPU 上,但它的性能還是比不上直接用 CUDA 編寫的代碼,畢竟 ATen 沒法并不知道如何去優(yōu)化算法的性能。不過,由于我對(duì) CUDA 仍一竅不通,因此這一步只能暫時(shí)略過,留待之后補(bǔ)充~囧~。
參考
CUSTOM C EXTENSIONS FOR PYTORCH
CUSTOM C++ AND CUDA EXTENSIONS
Pytorch拓展進(jìn)階(一):Pytorch結(jié)合C以及Cuda語言
Pytorch拓展進(jìn)階(二):Pytorch結(jié)合C++以及Cuda拓展
到此這篇關(guān)于PyTorch中的C++擴(kuò)展實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)PyTorch C++擴(kuò)展 內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python進(jìn)行ffmpeg推流和拉流rtsp、rtmp實(shí)例詳解
Python推流本質(zhì)是調(diào)用FFmpeg的推流進(jìn)程,下面這篇文章主要給大家介紹了關(guān)于Python進(jìn)行ffmpeg推流和拉流rtsp、rtmp的相關(guān)資料,需要的朋友可以參考下2023-01-01django rest framework 實(shí)現(xiàn)用戶登錄認(rèn)證詳解
這篇文章主要介紹了django rest framework 實(shí)現(xiàn)用戶登錄認(rèn)證詳解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07python測(cè)試驅(qū)動(dòng)開發(fā)實(shí)例
這篇文章主要介紹了python測(cè)試驅(qū)動(dòng)開發(fā)實(shí)例,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2014-10-10Python中Yield的基本用法及Yield與return的區(qū)別解析
Python中有一個(gè)非常有用的語法叫做生成器,用到的關(guān)鍵字就是yield,這篇文章主要介紹了Python中Yield的基本用法及Yield與return的區(qū)別,需要的朋友可以參考下2022-10-10python創(chuàng)建多個(gè)logging日志文件的方法實(shí)現(xiàn)
本文主要介紹了python創(chuàng)建多個(gè)logging日志文件的方法實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-07-07python中用cantools和can工具包解析blf文件的方法
這篇文章主要給大家介紹了關(guān)于python中用cantools和can工具包解析blf文件的相關(guān)資料,blf數(shù)據(jù)不像mf4那樣自帶信號(hào)數(shù)據(jù)庫(kù),因?yàn)樗怯浫罩居玫?一般情況下要盡可能的小,需要的朋友可以參考下2023-09-09Pythont特殊語法filter,map,reduce,apply使用方法
這篇文章主要介紹了Pythont特殊語法filter,map,reduce,apply使用方法,需要的朋友可以參考下2016-02-02